Example #1
class Trainer:
    out: Path
    out_prefix: str
    data_loaders: DataLoaderCollection
    metrics: MetricsCollection
    train_rate: float = 0.1

    max_epochs: int
    start_epoch: int
    save_checkpoint_epochs: int

    loss_metric: Loss
    optimizer: Optimizer
    model: Module
    device: torch.device
    half: bool
    ddp: bool
    test_no_grad: bool

    checkpoint: Checkpoint
    best_checkpoint: Checkpoint

    _training_model: Module
    _training_optimizer: Optimizer

    def __init__(self,
                 out_prefix: str,
                 data_loaders: DataLoaderCollection,
                 metrics: MetricsCollection,
                 loss_function: Callable[[torch.Tensor, torch.Tensor],
                                         torch.Tensor],
                 model: Module,
                 optimizer: Optimizer,
                 out: Path,
                 train_rate: float,
                 max_epochs: int,
                 save_checkpoint_epochs: int,
                 device: torch.device,
                 half: bool,
                 ddp: bool,
                 test_no_grad: bool = True):
        self.out_prefix = out_prefix
        self.data_loaders = data_loaders
        self.metrics = deepcopy(metrics)
        self.loss_metric = Loss(loss_function)

        self.metrics.train_metrics.insert(0, self.loss_metric)
        # The test loss uses its own Loss instance so its running state stays
        # independent of the training loss.
        self.metrics.test_metrics.insert(0, Loss(loss_function))

        self.model = model
        self.optimizer = optimizer

        self.out = out
        self.train_rate = train_rate
        self.max_epochs = max_epochs
        self.save_checkpoint_epochs = save_checkpoint_epochs
        self.device = device
        self.half = half
        self.ddp = ddp
        self.test_no_grad = test_no_grad
        self.start_epoch = 0

        self.checkpoint = Checkpoint()
        self.best_checkpoint = Checkpoint()
        self.best_checkpoint.indicate = inf

        self._training_model = self.model
        self._training_optimizer = self.optimizer
        self.setup_training_model_and_optimizer()

    def train(self):
        self.out.mkdir(parents=True, exist_ok=True)

        for epoch in range(self.start_epoch, self.max_epochs):
            self.checkpoint.set_epoch(epoch)
            self.train_epoch(epoch)
            self.test(self.test_no_grad)

            torch.cuda.empty_cache()

            self.checkpoint.add_metrics(self.metrics)
            self.checkpoint.set_indicate(self.calc_indicate(self.train_rate))

            if self.checkpoint.indicate <= self.best_checkpoint.indicate:
                self.checkpoint.set_states(self.model, self.optimizer)
                self.best_checkpoint = self.checkpoint.copy()
                self.save_model(self.out / f'{self.out_prefix}_backup_best.pt',
                                self.best_checkpoint)

            if (epoch + 1) % self.save_checkpoint_epochs == 0:
                self.checkpoint.set_states(self.model, self.optimizer)
                self.save_model(
                    self.out / f'{self.out_prefix}_backup_{epoch + 1}.pt',
                    self.checkpoint)
            print(
                f"\r"
                f"epoch {epoch + 1}/{self.max_epochs} / "
                f"train({', '.join([str(metric) for metric in self.metrics.train_metrics])}) / "
                f"test({', '.join([str(metric) for metric in self.metrics.test_metrics])}) / "
                f"indicate(current={self.checkpoint.indicate:.6f}, best={self.best_checkpoint.indicate:.6f})"
            )
        self.save_model(self.out / f'{self.out_prefix}_best.pt',
                        self.best_checkpoint)

        self.checkpoint.set_states(self.model, self.optimizer)
        self.save_model(self.out / f'{self.out_prefix}_last.pt',
                        self.checkpoint)

    def setup_training_model_and_optimizer(self):
        self.model.to(self.device)
        if self.half:
            from apex import amp
            # Initialization
            opt_level = 'O1'
            self._training_model, self._training_optimizer = \
                amp.initialize(self._training_model, self._training_optimizer, opt_level=opt_level, verbosity=0)

        if self.ddp:
            # try port number incrementally
            port = 9999
            while True:
                try:
                    dist.init_process_group(
                        backend='nccl',
                        init_method=f'tcp://127.0.0.1:{port}',
                        world_size=1,
                        rank=0)

                except RuntimeError as e:
                    # The "address already in use" error may include extra context,
                    # so match the message as a substring rather than by equality.
                    if "Address already in use" in str(e):
                        port += 1
                        continue
                    raise
                break
            self._training_model = DistributedDataParallel(
                self._training_model, find_unused_parameters=True)

    def save_model(self, path: Path, checkpoint: Checkpoint):
        state_dict = checkpoint.state_dict()
        torch.save(state_dict, path)

    def calc_indicate(self, train_rate) -> float:
        test_rate = 1 - train_rate
        test_loss = self.metrics.test_metrics[0]
        return (self.loss_metric.get_value() ** train_rate
                * test_loss.get_value() ** test_rate)
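        # For example, with train_rate = 0.1, a train loss of 0.2 and a test loss
        # of 0.5 give 0.2 ** 0.1 * 0.5 ** 0.9 ≈ 0.46: the indicator is a weighted
        # geometric mean of the two losses, leaning heavily on the test loss.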

    def load_checkpoint(self, state_dict: dict):
        self.checkpoint.load(state_dict)
        self.start_epoch = self.checkpoint.epoch + 1
        self.model.load_state_dict(self.checkpoint.model_state)
        self.optimizer.load_state_dict(self.checkpoint.optimizer_state)

    def train_epoch(self, epoch: int):
        self._training_model.train()
        self.reset_all_metrics_states(self.metrics.train_metrics)

        for i, (data, target) in enumerate(self.data_loaders.train_loader):
            data = data.to(self.device)
            target = target.to(self.device)

            pred = self._training_model(data)
            evaluated = self.evaluate_metrics(pred, target,
                                              self.metrics.train_metrics)
            loss = self.loss_metric.get_tensor()
            if self.half:
                from apex import amp
                # Use the amp-initialized optimizer here, matching the
                # step()/zero_grad() calls on _training_optimizer below.
                with amp.scale_loss(loss, self._training_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            self._training_optimizer.step()
            self._training_optimizer.zero_grad()

            mem_info = ""
            if self.device.type != "cpu":
                torch.cuda.synchronize()
                mem_info = f" / mem(allocated={torch.cuda.memory_allocated()}, cached={torch.cuda.memory_cached()})"

            print("\r"
                  f"epoch {epoch + 1}/{self.max_epochs} / "
                  f"batch {i + 1}/{len(self.data_loaders.train_loader)} / "
                  f"train ({self.get_evaluated_metrics_info(evaluated)})" +
                  mem_info,
                  end="")

    def test(self,
             no_grad: bool = True,
             loader: DataLoader = None,
             metrics: List[MetricBase] = None):
        self._training_model.eval()

        grad_policy = torch.no_grad if no_grad else torch.enable_grad
        loader = loader or self.data_loaders.test_loader
        metrics = metrics or self.metrics.test_metrics

        self.reset_all_metrics_states(metrics)

        with grad_policy():
            for i, (data, target) in enumerate(loader):
                data = data.to(self.device)
                target = target.to(self.device)

                pred = self._training_model(data)
                self.evaluate_metrics(pred, target, metrics)

    def reset_all_metrics_states(self, metrics: List[MetricBase]):
        for metric in metrics:
            metric.reset_states()

    def evaluate_metrics(self, pred: torch.Tensor, target: torch.Tensor,
                         metrics: List[MetricBase]) -> Dict[str, float]:
        # The loss metric keeps the computation graph so it can backpropagate
        # when gradients are enabled.
        loss_metric = metrics[0]
        result = {loss_metric.get_name(): loss_metric.evaluate(pred, target)}

        # The other metrics should not track gradients, so detach the tensors first.
        pred = pred.detach()
        target = target.detach()

        for metric in metrics[1:]:
            result[metric.get_name()] = metric.evaluate(pred, target)

        return result

    def get_evaluated_metrics_info(self, evaluated_infos: Dict[str,
                                                               float]) -> str:
        return ', '.join(
            [f'{name}={info:.6f}' for name, info in evaluated_infos.items()])
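
# A minimal usage sketch for the Trainer above. DataLoaderCollection,
# MetricsCollection and Checkpoint come from the surrounding project and are
# not shown here, so the `data_loaders` and `metrics` objects below are
# placeholders, not runnable values.
import torch
import torch.nn.functional as F
from pathlib import Path

model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

trainer = Trainer(out_prefix="demo",
                  data_loaders=data_loaders,  # assumed DataLoaderCollection
                  metrics=metrics,            # assumed MetricsCollection
                  loss_function=F.cross_entropy,
                  model=model,
                  optimizer=optimizer,
                  out=Path("runs/demo"),
                  train_rate=0.1,
                  max_epochs=10,
                  save_checkpoint_epochs=5,
                  device=torch.device("cpu"),
                  half=False,
                  ddp=False)
trainer.train()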
Example #2
def train():
    config_file = "configs/train_daily_dialog_topic_config.json"
    config = Config.from_json_file(config_file)

    # Logging is set to INFO for the main process and WARN for auxiliary processes:
    # logger.info logs only from the main process, logger.warning logs from all processes.
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", config.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order matters: distributed wrapping should come last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = tuple(
            input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids,
                                 token_action_ids)
        loss = (lm_loss * config.lm_coef +
                mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch
            #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr),
                                 (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], config),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(config,
                   tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
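
# `average_distributed_scalar` is used in the metrics above but not defined in
# this excerpt. A reasonable sketch (an assumption, not necessarily the exact
# helper used here) averages the scalar across processes with an all_reduce:
import torch

def average_distributed_scalar(scalar, config):
    # No-op in single-process runs; otherwise average over all workers.
    if config.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float,
                            device=config.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()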
Example #3
    def lops(self, rank, data_dict, optimizer, scheduler, percent_thresh=0.5):
        print("LOPS running..")
        print("Before LOPS: Input Ids", data_dict["input_ids"].shape)
        print("Before LOPS: Mask", data_dict["attention_masks"].shape)
        print("Before LOPS: Labels", data_dict["labels"].shape)
        model = LOTClassModel.from_pretrained(self.pretrained_lm,
                                              output_attentions=False,
                                              output_hidden_states=False,
                                              num_labels=self.num_class)
        model = model.to(rank)
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
        model.train()
        y_pseudo = data_dict["all_target_pred"].numpy()
        dataset = TensorDataset(data_dict["input_ids"], data_dict["attention_masks"], data_dict["labels"])
        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=rank)
        train_dataloader = DataLoader(dataset, sampler=sampler, batch_size=self.train_batch_size, shuffle=False)

        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=rank, shuffle=False)
        prediction_dataloader = DataLoader(dataset,
                                           sampler=sampler,
                                           batch_size=self.eval_batch_size,
                                           shuffle=False
                                           )

        inds_map = {}
        for i, j in enumerate(y_pseudo):
            try:
                inds_map[j].append(i)
            except KeyError:
                inds_map[j] = [i]

        thresh_map = dict(Counter(y_pseudo))
        print("Counts of pseudo-labels ", thresh_map, flush=True)
        for i in thresh_map:
            thresh_map[i] = int(thresh_map[i] * percent_thresh)

        print("Threshold map ", thresh_map, flush=True)

        filter_flag_map = {}
        train_inds_map = {}
        non_train_inds_map = {}
        for i in thresh_map:
            filter_flag_map[i] = False
            train_inds_map[i] = []
            non_train_inds_map[i] = []

        total_train_loss = 0
        wrap_train_dataset_loader = tqdm(train_dataloader) if rank == 0 else train_dataloader
        model.zero_grad()
        try:
            for j, batch in enumerate(wrap_train_dataset_loader):
                input_ids = batch[0].to(rank)
                input_mask = batch[1].to(rank)
                target_dist = batch[2].to(rank)
                logits = model(input_ids,
                               pred_mode="classification",
                               token_type_ids=None,
                               attention_mask=input_mask)
                logits = logits[:, 0, :]
                preds = nn.LogSoftmax(dim=-1)(logits)
                loss = self.st_loss(preds.view(-1, self.num_class),
                                    target_dist.view(-1, self.num_class)) / self.accum_steps
                total_train_loss += loss.item()
                loss.backward()
                if (j + 1) % self.accum_steps == 0:
                    # Clip the norm of the gradients to 1.0.
                    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()

            avg_train_loss = torch.tensor([total_train_loss / len(wrap_train_dataset_loader) * self.accum_steps]).to(
                rank)
            gather_list = [torch.ones_like(avg_train_loss) for _ in range(self.world_size)]
            dist.all_gather(gather_list, avg_train_loss)
            avg_train_loss = torch.tensor(gather_list)
            if rank == 0:
                print(f"lr: {optimizer.param_groups[0]['lr']:.4g}")
                print(f"Average training loss: {avg_train_loss.mean().item()}")
        except RuntimeError as err:
            self.cuda_mem_error(err, "train", rank)

        input_ids, input_mask, preds = self.inference(model, prediction_dataloader, rank, return_type="data")
        gather_input_ids = [torch.ones_like(input_ids) for _ in range(self.world_size)]
        gather_input_mask = [torch.ones_like(input_mask) for _ in range(self.world_size)]
        gather_preds = [torch.ones_like(preds) for _ in range(self.world_size)]
        dist.all_gather(gather_input_ids, input_ids)
        dist.all_gather(gather_input_mask, input_mask)
        dist.all_gather(gather_preds, preds)
        all_preds = torch.cat(gather_preds, dim=0).cpu()
        pred_inds = all_preds.argmax(dim=-1)

        count = 0
        for i in filter_flag_map:
            if not filter_flag_map[i]:
                train_inds, non_train_inds = self.compute_train_non_train_inds(pred_inds, y_pseudo, inds_map, i)
                train_inds_map[i] = train_inds
                non_train_inds_map[i] = non_train_inds
                if len(train_inds) >= thresh_map[i]:
                    filter_flag_map[i] = True
                    count += 1
            else:
                count += 1

        print("Number of labels reached 50 percent threshold", count)
        for i in filter_flag_map:
            if not filter_flag_map[i]:
                print("For label ", i, " Number expected ", thresh_map[i], " Found ", len(train_inds_map[i]))

        for i in filter_flag_map:
            if not filter_flag_map[i]:
                print("Resetting train, non-train inds for label ", i)
                train_inds_map[i] = inds_map[i]
                non_train_inds_map[i] = []

        ret_input_ids = []
        ret_masks = []
        ret_labels = []
        for lbl in train_inds_map:
            for loop_ind in train_inds_map[lbl]:
                ret_input_ids.append(data_dict["input_ids"][loop_ind].unsqueeze(0))
                ret_masks.append(data_dict["attention_masks"][loop_ind].unsqueeze(0))
                ret_labels.append(data_dict["labels"][loop_ind].unsqueeze(0))

        all_input_ids = torch.cat(ret_input_ids, dim=0)
        all_input_mask = torch.cat(ret_masks, dim=0)
        all_labels = torch.cat(ret_labels, dim=0)
        print("After LOPS: Input Ids", all_input_ids.shape)
        print("After LOPS: Mask", all_input_mask.shape)
        print("After LOPS: Labels", all_labels.shape)

        self_train_dict = {"input_ids": all_input_ids, "attention_masks": all_input_mask, "labels": all_labels}
        return self_train_dict
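
    # `compute_train_non_train_inds` is called above but not shown. A plausible
    # sketch (an assumption about its behavior, not the original helper): keep
    # the indices pseudo-labeled `lbl` where the model's prediction agrees with
    # the pseudo-label, and set aside the rest.
    def compute_train_non_train_inds(self, pred_inds, y_pseudo, inds_map, lbl):
        train_inds, non_train_inds = [], []
        for ind in inds_map[lbl]:  # indices whose pseudo-label is `lbl`
            if pred_inds[ind] == lbl:
                train_inds.append(ind)
            else:
                non_train_inds.append(ind)
        return train_inds, non_train_inds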
Example #4
    def setup_experiment(self, config):
        """
        Configure the experiment for training

        :param config: Dictionary containing the configuration parameters

            - distributed: Whether or not to use Pytorch Distributed training
            - backend: Pytorch Distributed backend ("nccl", "gloo")
                    Default: nccl
            - world_size: Total number of processes participating
            - rank: Rank of the current process
            - data: Dataset path
            - train_dir: Dataset training data relative path
            - batch_size: Training batch size
            - val_dir: Dataset validation data relative path
            - val_batch_size: Validation batch size
            - workers: how many data loading processes to use
            - train_loader_drop_last: Whether to skip last batch if it is
                                      smaller than the batch size
            - num_classes: Limit the dataset size to the given number of classes
            - model_class: Model class. Must inherit from "torch.nn.Module"
            - model_args: model class arguments passed to the constructor
            - init_batch_norm: Whether or not to initialize the running batch norm
                               mean to 0.
            - optimizer_class: Optimizer class.
                               Must inherit from "torch.optim.Optimizer"
            - optimizer_args: Optimizer class arguments passed to the
                              constructor
            - batch_norm_weight_decay: Whether or not to apply weight decay to
                                       batch norm modules parameters
                                       See https://arxiv.org/abs/1807.11205
            - bias_weight_decay: Whether or not to apply weight decay to
                                       bias parameters
            - lr_scheduler_class: Learning rate scheduler class.
                                 Must inherit from "_LRScheduler"
            - lr_scheduler_args: Learning rate scheduler class arguments
                                 passed to the constructor
            - loss_function: Loss function. See "torch.nn.functional"
            - local_dir: Results path
            - logdir: Directory generated by Ray Tune for this Trial
            - epochs: Number of epochs to train
            - batches_in_epoch: Number of batches per epoch.
                                Useful for debugging
            - log_timestep_freq: Configures mixins and subclasses that log every
                                 timestep to only log every nth timestep (in
                                 addition to the final timestep of each epoch).
                                 Set to 0 to log only at the end of each epoch.
            - progress: Show progress during training
            - name: Experiment name. Used as logger name
            - log_level: Python Logging level
            - log_format: Python Logging format
            - seed: the seed to be used for pytorch, python, and numpy
            - mixed_precision: Whether or not to enable apex mixed precision
            - mixed_precision_args: apex mixed precision arguments.
                                    See "amp.initialize"
            - sample_transform: Transform acting on the training samples. To be used
                                additively after default transform or auto-augment.
            - target_transform: Transform acting on the training targets.
            - replicas_per_sample: Number of replicas to create per sample in the batch.
                                   (each replica is transformed independently)
                                   Used in maxup.
            - train_model_func: Optional user defined function to train the model,
                                expected to behave similarly to `train_model`
                                in terms of input parameters and return values
            - evaluate_model_func: Optional user defined function to validate the model
                                   expected to behave similarly to `evaluate_model`
                                   in terms of input parameters and return values
            - checkpoint_file: if not None, will start from this model. The model
                               must have the same model_args and model_class as the
                               current experiment.
            - resize_buffers_for_checkpoint: if True, this will resize the model
                                             buffers to match those in the checkpoint.
                                             This is helpful for loading buffers with
                                             sparse levels not matching the model_args
            - checkpoint_at_init: boolean argument for whether to create a checkpoint
                                  of the initialized model. This differs from
                                  `checkpoint_at_start` for which the checkpoint occurs
                                  after the first epoch of training as opposed to
                                  before it
            - epochs_to_validate: list of epochs to run validate(). A -1 asks
                                  to run validate before any training occurs.
                                  Default: last three epochs.
            - extra_validations_per_epoch: number of additional validations to
                                           perform mid-epoch. Additional
                                           validations are distributed evenly
                                           across training batches.
            - launch_time: time the config was created (via time.time). Used to report
                           wall clock time until the first batch is done.
                           Default: time.time() in this setup_experiment().
        """
        # Configure logging related stuff
        log_format = config.get("log_format", logging.BASIC_FORMAT)
        log_level = getattr(logging, config.get("log_level", "INFO").upper())
        console = logging.StreamHandler()
        console.setFormatter(logging.Formatter(log_format))
        self.logger = logging.getLogger(config.get("name",
                                                   type(self).__name__))
        self.logger.setLevel(log_level)
        self.logger.addHandler(console)
        self.progress = config.get("progress", False)
        self.launch_time = config.get("launch_time", time.time())
        self.logdir = config.get("logdir", None)

        # Configure seed
        self.seed = config.get("seed", self.seed)
        set_random_seed(self.seed, False)

        # Configure distributed PyTorch
        self.distributed = config.get("distributed", False)
        self.rank = config.get("rank", 0)

        if self.rank == 0:
            self.logger.info(
                f"Execution order: {pformat(self.get_execution_order())}")

        if self.distributed:
            dist_url = config.get("dist_url", "tcp://127.0.0.1:54321")
            backend = config.get("backend", "nccl")
            world_size = config.get("world_size", 1)
            dist.init_process_group(
                backend=backend,
                init_method=dist_url,
                rank=self.rank,
                world_size=world_size,
            )
            # Only enable logs from first process
            self.logger.disabled = self.rank != 0
            self.progress = self.progress and self.rank == 0

        # Configure model
        self.device = config.get("device", self.device)
        self.model = self.create_model(config, self.device)
        if self.rank == 0:
            self.logger.debug(self.model)

        # Configure optimizer
        group_decay, group_no_decay = [], []
        for module in self.model.modules():
            for name, param in module.named_parameters(recurse=False):
                if self.should_decay_parameter(module, name, param, config):
                    group_decay.append(param)
                else:
                    group_no_decay.append(param)

        optimizer_class = config.get("optimizer_class", torch.optim.SGD)
        optimizer_args = config.get("optimizer_args", {})
        self.optimizer = optimizer_class([
            dict(params=group_decay),
            dict(params=group_no_decay, weight_decay=0.)
        ], **optimizer_args)

        # Validate mixed precision requirements
        self.mixed_precision = config.get("mixed_precision", False)
        if self.mixed_precision and amp is None:
            self.mixed_precision = False
            self.logger.error(
                "Mixed precision requires NVIDIA Apex. "
                "Please install apex from https://www.github.com/nvidia/apex. "
                "Disabling mixed precision training.")

        # Configure mixed precision training
        if self.mixed_precision:
            amp_args = config.get("mixed_precision_args", {})
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, **amp_args)
            self.logger.info("Using mixed precision")

        # Apply DistributedDataParallel after all other model mutations
        if self.distributed:
            self.model = DistributedDataParallel(self.model)
        else:
            self.model = DataParallel(self.model)

        self._loss_function = config.get("loss_function",
                                         torch.nn.functional.cross_entropy)

        self.num_classes = config.get("num_classes", 1000)
        self.epochs = config.get("epochs", 1)
        self.batches_in_epoch = config.get("batches_in_epoch", sys.maxsize)
        self.current_epoch = 0

        # Get initial batch size
        self.batch_size = config.get("batch_size", 1)

        # CUDA runtime does not support the fork start method.
        # See https://pytorch.org/docs/stable/notes/multiprocessing.html
        multiprocessing.set_start_method("spawn", force=True)

        # Configure data loaders
        self.train_loader = self.create_train_dataloader(config)
        self.val_loader = self.create_validation_dataloader(config)
        self.total_batches = len(self.train_loader)

        self.epochs_to_validate = config.get(
            "epochs_to_validate", range(self.epochs - 3, self.epochs + 1))

        extra_validations = config.get("extra_validations_per_epoch", 0)
        batches_to_validate = np.linspace(
            min(self.total_batches, self.batches_in_epoch),
            0,
            1 + extra_validations,
            endpoint=False)[::-1].round().astype("int").tolist()
        self.additional_batches_to_validate = batches_to_validate[:-1]
        if extra_validations > 0:
            self.logger.info(
                f"Extra validations per epoch: {extra_validations}, "
                f"batch indices: {self.additional_batches_to_validate}")

        # Used for logging. Conceptually, it is a version number for the model's
        # parameters. By default, this is the elapsed number of batches that the
        # model has been trained on. Experiments may also increment this on
        # other events like model prunings. When validation is performed after a
        # training batch, the validation results are assigned to the next
        # timestep after that training batch, since it was performed on the
        # subsequent version of the parameters.
        self.current_timestep = 0
        self.log_timestep_freq = config.get("log_timestep_freq", 1)

        # A list of [(timestep, result), ...] for the current epoch.
        self.extra_val_results = []

        # Configure learning rate scheduler
        lr_scheduler_class = config.get("lr_scheduler_class", None)
        if lr_scheduler_class is not None:
            lr_scheduler_args = config.get("lr_scheduler_args", {})
            self.logger.info("LR Scheduler args:")
            self.logger.info(pformat(lr_scheduler_args))
            self.logger.info("steps_per_epoch=%s", self.total_batches)
            self.lr_scheduler = create_lr_scheduler(
                optimizer=self.optimizer,
                lr_scheduler_class=lr_scheduler_class,
                lr_scheduler_args=lr_scheduler_args,
                steps_per_epoch=self.total_batches)

        # Set train and validate methods.
        self.train_model = config.get("train_model_func", train_model)
        self.evaluate_model = config.get("evaluate_model_func", evaluate_model)
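
# A minimal, illustrative config for setup_experiment above. MyModel, the
# dataset path and the concrete values are assumptions for this sketch, not
# settings from the original project.
import torch

example_config = dict(
    name="example_run",
    seed=42,
    distributed=False,
    device="cuda",
    model_class=MyModel,                # hypothetical torch.nn.Module subclass
    model_args=dict(num_classes=10),
    optimizer_class=torch.optim.SGD,
    optimizer_args=dict(lr=0.1, momentum=0.9, weight_decay=1e-4),
    lr_scheduler_class=torch.optim.lr_scheduler.StepLR,
    lr_scheduler_args=dict(step_size=30, gamma=0.1),
    loss_function=torch.nn.functional.cross_entropy,
    data="~/datasets/imagenet",         # assumed dataset root
    train_dir="train",
    val_dir="val",
    batch_size=128,
    val_batch_size=256,
    workers=4,
    epochs=90,
    log_level="info",
)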
Example #5
class GradientDescentTrainer(Trainer):
    def __init__(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        data_loader: torch.utils.data.DataLoader,
        patience: Optional[int] = None,
        validation_metric: str = "-loss",
        validation_data_loader: torch.utils.data.DataLoader = None,
        num_epochs: int = 20,
        serialization_dir: Optional[str] = None,
        checkpointer: Checkpointer = None,
        model_save_interval: float = None,
        cuda_device: int = -1,
        grad_norm: Optional[float] = None,
        grad_clipping: Optional[float] = None,
        learning_rate_scheduler: Optional[LearningRateScheduler] = None,
        momentum_scheduler: Optional[MomentumScheduler] = None,
        tensorboard_writer: TensorboardWriter = None,
        log_batch_size_period: Optional[int] = None,
        moving_average: Optional[MovingAverage] = None,
        distributed: bool = False,
        local_rank: int = 0,
        world_size: int = 1,
        num_gradient_accumulation_steps: int = 1,
        opt_level: Optional[str] = None,
    ) -> None:
        """
        A trainer for doing supervised learning. It just takes a labeled dataset
        and a `DataLoader`, and uses the supplied `Optimizer` to learn the weights
        for your model over some fixed number of epochs. You can also pass in a validation
        dataloader and enable early stopping. There are many other bells and whistles as well.

        # Parameters

        model : `Model`, required.
            An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
            their `forward` method returns a dictionary with a "loss" key, containing a
            scalar tensor representing the loss function to be optimized.

            If you are training your model using GPUs, your model should already be
            on the correct device. (If you use `Trainer.from_params` this will be
            handled for you.)
        optimizer : `torch.optim.Optimizer`, required.
            An instance of a Pytorch Optimizer, instantiated with the parameters of the
            model to be optimized.
        data_loader : `DataLoader`, required.
            A pytorch `DataLoader` containing your `Dataset`, yielding padded indexed batches.
        patience : Optional[int] > 0, optional (default=None)
            Number of epochs to be patient before early stopping: the training is stopped
            after `patience` epochs with no improvement. If given, it must be `> 0`.
            If None, early stopping is disabled.
        validation_metric : str, optional (default="-loss")
            Validation metric to measure for whether to stop training using patience
            and whether to serialize an `is_best` model each epoch. The metric name
            must be prepended with either "+" or "-", which specifies whether the metric
            is an increasing or decreasing function.
        validation_data_loader : `DataLoader`, optional (default=None)
            A `DataLoader` to use for the validation set.  If `None`, then
            use the training `DataLoader` with the validation data.
        num_epochs : int, optional (default = 20)
            Number of training epochs.
        serialization_dir : str, optional (default=None)
            Path to directory for saving and loading model files. Models will not be saved if
            this parameter is not passed.
        checkpointer : `Checkpointer`, optional (default=None)
            A `Checkpointer` is responsible for periodically saving model weights.  If none is given
            here, we will construct one with default parameters.
        model_save_interval : `float`, optional (default=None)
            If provided, then serialize models every `model_save_interval`
            seconds within single epochs.  In all cases, models are also saved
            at the end of every epoch if `serialization_dir` is provided.
        cuda_device : `int`, optional (default = -1)
            An integer specifying the CUDA device(s) to use for this process. If -1, the CPU is used.
            Data parallelism is controlled at the allennlp train level, so each trainer will have a single
            GPU.
        grad_norm : `float`, optional, (default = None).
            If provided, gradient norms will be rescaled to have a maximum of this value.
        grad_clipping : `float`, optional (default = `None`).
            If provided, gradients will be clipped `during the backward pass` to have an (absolute)
            maximum of this value.  If you are getting `NaNs` in your gradients during training
            that are not solved by using `grad_norm`, you may need this.
        learning_rate_scheduler : `LearningRateScheduler`, optional (default = None)
            If specified, the learning rate will be decayed with respect to
            this schedule at the end of each epoch (or batch, if the scheduler implements
            the `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`,
            this will use the `validation_metric` provided to determine if learning has plateaued.
            To support updating the learning rate on every batch, this can optionally implement
            `step_batch(batch_num_total)` which updates the learning rate given the batch number.
        momentum_scheduler : `MomentumScheduler`, optional (default = None)
            If specified, the momentum will be updated at the end of each batch or epoch
            according to the schedule.
        tensorboard_writer : `TensorboardWriter`, optional
            If this is not provided, we will construct a `TensorboardWriter` with default
            parameters and use that.
        log_batch_size_period : `int`, optional, (default = `None`)
            If defined, how often to log the average batch size.
        moving_average : `MovingAverage`, optional, (default = None)
            If provided, we will maintain moving averages for all parameters. During training, we
            employ a shadow variable for each parameter, which maintains the moving average. During
            evaluation, we backup the original parameters and assign the moving averages to corresponding
            parameters. Be careful that when saving the checkpoint, we will save the moving averages of
            parameters. This is necessary because we want the saved model to perform as well as the validated
            model if we load it later. But this may cause problems if you restart the training from checkpoint.
        distributed : `bool`, optional, (default = False)
            If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also
            requires `world_size` to be greater than 1.
        local_rank : `int`, optional, (default = 0)
            This is the unique identifier of the `Trainer` in a distributed process group. The GPU device id is
            used as the rank.
        world_size : `int`, (default = 1)
            The number of `Trainer` workers participating in the distributed training.
        num_gradient_accumulation_steps : `int`, optional, (default = 1)
            Gradients are accumulated for the given number of steps before doing an optimizer step. This can
            be useful to accommodate batches that are too large to fit in memory. Refer to Thomas Wolf's
            [post](https://tinyurl.com/y5mv44fw) for details on Gradient Accumulation.
        opt_level : `str`, optional, (default = `None`)
            Each opt_level establishes a set of properties that govern Amp’s implementation of pure or mixed
            precision training. Must be a choice of `"O0"`, `"O1"`, `"O2"`, or `"O3"`.
            See the Apex [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties) for
            more details. If `None`, Amp is not used. Defaults to `None`.
        """
        super().__init__(serialization_dir, cuda_device, distributed,
                         local_rank, world_size)

        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        self.model = model

        self.data_loader = data_loader
        self._validation_data_loader = validation_data_loader
        self.optimizer = optimizer

        if patience is None:  # no early stopping
            if validation_data_loader:
                logger.warning(
                    "You provided a validation dataset but patience was set to None, "
                    "meaning that early stopping is disabled")
        elif (not isinstance(patience, int)) or patience <= 0:
            raise ConfigurationError(
                '{} is an invalid value for "patience": it must be a positive integer '
                "or None (if you want to disable early stopping)".format(
                    patience))

        # For tracking is_best_so_far and should_stop_early
        self._metric_tracker = MetricTracker(patience, validation_metric)
        # Get rid of + or -
        self._validation_metric = validation_metric[1:]

        self._num_epochs = num_epochs

        if checkpointer is not None:
            self._checkpointer = checkpointer
        else:
            self._checkpointer = Checkpointer(serialization_dir)

        self._model_save_interval = model_save_interval

        self._grad_norm = grad_norm
        self._grad_clipping = grad_clipping

        self._learning_rate_scheduler = learning_rate_scheduler
        self._momentum_scheduler = momentum_scheduler
        self._moving_average = moving_average

        # We keep the total batch number as an instance variable because it
        # is used inside a closure for the hook which logs activations in
        # `_enable_activation_logging`.
        self._batch_num_total = 0

        self._tensorboard = tensorboard_writer or TensorboardWriter(
            serialization_dir)
        self._tensorboard.get_batch_num_total = lambda: self._batch_num_total
        self._tensorboard.enable_activation_logging(self.model)

        self._log_batch_size_period = log_batch_size_period

        self._last_log = 0.0  # time of last logging

        self._num_gradient_accumulation_steps = num_gradient_accumulation_steps

        # Enable automatic mixed precision training with NVIDIA Apex.
        self._opt_level = opt_level
        if self._opt_level is not None:
            if amp is None:
                raise ConfigurationError((
                    "Apex not installed but opt_level was provided. Please install NVIDIA's Apex to enable"
                    " automatic mixed precision (AMP) training. See: https://github.com/NVIDIA/apex."
                ))

            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self._opt_level)

        # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
        # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
        # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
        #
        # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the
        # normal case, reference to `Model` is retained. This reference is only used in
        # these places: `model.__call__`, `model.train` and `model.eval`.
        if self._distributed:
            self._pytorch_model = DistributedDataParallel(
                self.model,
                device_ids=[self.cuda_device],
                find_unused_parameters=True)
        else:
            self._pytorch_model = self.model
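
    # A minimal, hypothetical instantiation of this trainer, assuming an AllenNLP
    # `Model`, `Optimizer` and `DataLoader` have already been built. The `train()`
    # entry point belongs to the full class but is not shown in this excerpt:
    #
    #     trainer = GradientDescentTrainer(model, optimizer, data_loader,
    #                                      validation_data_loader=val_loader,
    #                                      patience=5, num_epochs=20,
    #                                      serialization_dir="output",
    #                                      cuda_device=0)
    #     metrics = trainer.train()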

    def rescale_gradients(self) -> Optional[float]:
        """
        Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
        """
        if self._grad_norm:
            if self._opt_level is not None:
                # See: https://nvidia.github.io/apex/advanced.html#gradient-clipping
                parameters_to_clip = [
                    p for p in amp.master_params(self.optimizer)
                    if p.grad is not None
                ]
            else:
                parameters_to_clip = [
                    p for p in self.model.parameters() if p.grad is not None
                ]
            return training_util.sparse_clip_norm(parameters_to_clip,
                                                  self._grad_norm)
        else:
            return None

    def batch_loss(self, batch: TensorDict,
                   for_training: bool) -> torch.Tensor:
        """
        Does a forward pass on the given batches and returns the `loss` value in the result.
        If `for_training` is `True` also applies regularization penalty.
        """
        batch = nn_util.move_to_device(batch, self.cuda_device)
        output_dict = self._pytorch_model(**batch)

        try:
            loss = output_dict["loss"]
            if for_training:
                loss += self.model.get_regularization_penalty()
        except KeyError:
            if for_training:
                raise RuntimeError(
                    "The model you are trying to optimize does not contain a"
                    " 'loss' key in the output of model.forward(inputs).")
            loss = None

        return loss

    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        peak_cpu_usage = common_util.peak_memory_mb()
        logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
        gpu_usage = []
        for gpu, memory in common_util.gpu_memory_mb().items():
            gpu_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self._pytorch_model.train()

        # Get tqdm for the training batches
        batch_generator = iter(self.data_loader)
        batch_group_generator = common_util.lazy_groups_of(
            batch_generator, self._num_gradient_accumulation_steps)

        logger.info("Training")

        num_training_batches = math.ceil(
            len(self.data_loader) / self._num_gradient_accumulation_steps)
        # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the master's
        # progress is shown.
        if self._master:
            batch_group_generator_tqdm = Tqdm.tqdm(batch_group_generator,
                                                   total=num_training_batches)
        else:
            batch_group_generator_tqdm = batch_group_generator

        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        histogram_parameters = set(
            self.model.get_parameters_for_histogram_tensorboard_logging())

        cumulative_batch_group_size = 0
        done_early = False
        for batch_group in batch_group_generator_tqdm:
            if self._distributed:
                # Check whether the other workers have stopped already (due to differing amounts of
                # data in each). If so, we can't proceed because we would hang when we hit the
                # barrier implicit in Model.forward. We use an IntTensor instead of a BoolTensor
                # here because NCCL process groups apparently don't support BoolTensor.
                done = torch.tensor(0, device=self.cuda_device)
                torch.distributed.all_reduce(done,
                                             torch.distributed.ReduceOp.SUM)
                if done.item() > 0:
                    done_early = True
                    logger.warning(
                        f"Worker {torch.distributed.get_rank()} finishing training early! "
                        "This implies that there is an imbalance in your training "
                        "data across the workers and that some amount of it will be "
                        "ignored. A small amount of this is fine, but a major imbalance "
                        "should be avoided. Note: This warning will appear unless your "
                        "data is perfectly balanced.")
                    break

            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self.optimizer.zero_grad()

            for batch in batch_group:
                loss = self.batch_loss(batch, for_training=True)
                if torch.isnan(loss):
                    raise ValueError("nan loss encountered")
                loss = loss / len(batch_group)
                if self._opt_level is not None:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_loss += loss.item()

            batch_grad_norm = self.rescale_gradients()

            # This does nothing if batch_num_total is None or you are using a
            # scheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)
            if self._momentum_scheduler:
                self._momentum_scheduler.step_batch(batch_num_total)

            if self._tensorboard.should_log_histograms_this_batch(
            ) and self._master:
                # get the magnitude of parameter updates for logging
                # We need a copy of current parameters to compute magnitude of updates,
                # and copy them to CPU so large models won't go OOM on the GPU.
                param_updates = {
                    name: param.detach().cpu().clone()
                    for name, param in self.model.named_parameters()
                }
                self.optimizer.step()
                for name, param in self.model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1))
                    param_norm = torch.norm(param.view(-1)).cpu()
                    self._tensorboard.add_train_scalar(
                        "gradient_update/" + name,
                        update_norm /
                        (param_norm +
                         nn_util.tiny_value_of_dtype(param_norm.dtype)),
                    )
            else:
                self.optimizer.step()

            # Update moving averages
            if self._moving_average is not None:
                self._moving_average.apply(batch_num_total)

            # Update the description with the latest metrics
            metrics = training_util.get_metrics(
                self.model,
                train_loss,
                batches_this_epoch,
                world_size=self._world_size,
                cuda_device=[self.cuda_device],
            )

            # Updating tqdm only for the master as the trainers wouldn't have one
            if self._master:
                description = training_util.description_from_metrics(metrics)
                batch_group_generator_tqdm.set_description(description,
                                                           refresh=False)

            # Log parameter values to Tensorboard (only from the master)
            if self._tensorboard.should_log_this_batch() and self._master:
                self._tensorboard.log_parameter_and_gradient_statistics(
                    self.model, batch_grad_norm)
                self._tensorboard.log_learning_rates(self.model,
                                                     self.optimizer)

                self._tensorboard.add_train_scalar("loss/loss_train",
                                                   metrics["loss"])
                self._tensorboard.log_metrics(
                    {"epoch_metrics/" + k: v
                     for k, v in metrics.items()})

            if self._tensorboard.should_log_histograms_this_batch(
            ) and self._master:
                self._tensorboard.log_histograms(self.model,
                                                 histogram_parameters)

            if self._log_batch_size_period:
                batch_group_size = sum(
                    training_util.get_batch_size(batch)
                    for batch in batch_group)
                cumulative_batch_group_size += batch_group_size
                if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                    average = cumulative_batch_group_size / batches_this_epoch
                    logger.info(
                        f"current batch size: {batch_group_size} mean batch size: {average}"
                    )
                    self._tensorboard.add_train_scalar("current_batch_size",
                                                       batch_group_size)
                    self._tensorboard.add_train_scalar("mean_batch_size",
                                                       average)

            # Save model if needed.
            if (self._model_save_interval is not None and
                (time.time() - last_save_time > self._model_save_interval)
                    and self._master):
                last_save_time = time.time()
                self._save_checkpoint("{0}.{1}".format(
                    epoch, training_util.time_to_str(int(last_save_time))))
        if self._distributed and not done_early:
            logger.warning(
                f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)."
            )
            # Indicate that we're done so that any workers that have remaining data stop the epoch early.
            done = torch.tensor(1, device=self.cuda_device)
            torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
            assert done.item()

        # Let all workers finish their epoch before computing
        # the final statistics for the epoch.
        if self._distributed:
            dist.barrier()

        metrics = training_util.get_metrics(
            self.model,
            train_loss,
            batches_this_epoch,
            reset=True,
            world_size=self._world_size,
            cuda_device=[self.cuda_device],
        )
        metrics["cpu_memory_MB"] = peak_cpu_usage
        for (gpu_num, memory) in gpu_usage:
            metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
        return metrics

    def _validation_loss(self) -> Tuple[float, int]:
        """
        Computes the validation loss. Returns it and the number of batches.
        """
        logger.info("Validating")

        self._pytorch_model.eval()

        # Replace parameter values with the shadow values from the moving averages.
        if self._moving_average is not None:
            self._moving_average.assign_average_value()

        if self._validation_data_loader is not None:
            validation_data_loader = self._validation_data_loader
        else:
            raise ConfigurationError(
                "Validation results cannot be calculated without a validation_data_loader"
            )

        val_generator_tqdm = Tqdm.tqdm(validation_data_loader)
        batches_this_epoch = 0
        val_loss = 0
        done_early = False
        for batch in val_generator_tqdm:
            if self._distributed:
                # Check whether the other workers have stopped already (due to differing amounts of
                # data in each). If so, we can't proceed because we would hang when we hit the
                # barrier implicit in Model.forward. We use an IntTensor instead of a BoolTensor
                # here because NCCL process groups apparently don't support BoolTensor.
                done = torch.tensor(0, device=self.cuda_device)
                torch.distributed.all_reduce(done,
                                             torch.distributed.ReduceOp.SUM)
                if done.item() > 0:
                    done_early = True
                    logger.warning(
                        f"Worker {torch.distributed.get_rank()} finishing validation early! "
                        "This implies that there is an imbalance in your validation "
                        "data across the workers and that some amount of it will be "
                        "ignored. A small amount of this is fine, but a major imbalance "
                        "should be avoided. Note: This warning will appear unless your "
                        "data is perfectly balanced.")
                    break

            loss = self.batch_loss(batch, for_training=False)
            if loss is not None:
                # You shouldn't necessarily have to compute a loss for validation, so we allow for
                # `loss` to be None.  We need to be careful, though - `batches_this_epoch` is
                # currently only used as the divisor for the loss function, so we can safely only
                # count those batches for which we actually have a loss.  If this variable ever
                # gets used for something else, we might need to change things around a bit.
                batches_this_epoch += 1
                val_loss += loss.detach().cpu().numpy()

            # Update the description with the latest metrics
            val_metrics = training_util.get_metrics(
                self.model,
                val_loss,
                batches_this_epoch,
                world_size=self._world_size,
                cuda_device=[self.cuda_device],
            )
            description = training_util.description_from_metrics(val_metrics)
            val_generator_tqdm.set_description(description, refresh=False)

        if self._distributed and not done_early:
            logger.warning(
                f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)."
            )
            # Indicate that we're done so that any workers that have remaining data stop validation early.
            done = torch.tensor(1, device=self.cuda_device)
            torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
            assert done.item()

        # Now restore the original parameter values.
        if self._moving_average is not None:
            self._moving_average.restore()

        return val_loss, batches_this_epoch

    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: Optional[float] = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if "cpu_memory_MB" in train_metrics:
                metrics["peak_cpu_memory_MB"] = max(
                    metrics.get("peak_cpu_memory_MB", 0),
                    train_metrics["cpu_memory_MB"])
            for key, value in train_metrics.items():
                if key.startswith("gpu_"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            if self._validation_data_loader is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()

                    # It is safe again to wait till the validation is done. This is
                    # important to get the metrics right.
                    if self._distributed:
                        dist.barrier()

                    val_metrics = training_util.get_metrics(
                        self.model,
                        val_loss,
                        num_batches,
                        reset=True,
                        world_size=self._world_size,
                        cuda_device=[self.cuda_device],
                    )

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            if self._master:
                self._tensorboard.log_metrics(
                    train_metrics,
                    val_metrics=val_metrics,
                    log_to_console=True,
                    epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir and self._master:
                common_util.dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f"metrics_epoch_{epoch}.json"), metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric)

            if self._master:
                self._save_checkpoint(epoch)

            # Wait for the master to finish saving the checkpoint
            if self._distributed:
                dist.barrier()

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                # Remaining time is estimated as elapsed * (total_epochs / epochs_completed - 1),
                # with both counts measured from the restored epoch_counter.
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) /
                    float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics

    def _save_checkpoint(self, epoch: Union[int, str]) -> None:
        """
        Saves a checkpoint of the model to self._serialization_dir.
        Is a no-op if self._serialization_dir is None.

        # Parameters

        epoch : Union[int, str], required.
            The epoch of training.  If the checkpoint is saved in the middle
            of an epoch, the parameter is a string with the epoch and timestamp.
        """
        # If moving averages are used for parameters, we save
        # the moving average values into checkpoint, instead of the current values.
        if self._moving_average is not None:
            self._moving_average.assign_average_value()

        # These are the training states we need to persist.
        training_states = {
            "metric_tracker": self._metric_tracker.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "batch_num_total": self._batch_num_total,
        }

        # If we have a learning rate or momentum scheduler, we should persist them too.
        if self._learning_rate_scheduler is not None:
            training_states[
                "learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict(
                )
        if self._momentum_scheduler is not None:
            training_states[
                "momentum_scheduler"] = self._momentum_scheduler.state_dict()

        self._checkpointer.save_checkpoint(
            model_state=self.model.state_dict(),
            epoch=epoch,
            training_states=training_states,
            is_best_so_far=self._metric_tracker.is_best_so_far(),
        )

        # Restore the original values for parameters so that training will not be affected.
        if self._moving_average is not None:
            self._moving_average.restore()

    def _restore_checkpoint(self) -> int:
        """
        Restores the model and training state from the last saved checkpoint.
        This includes an epoch count and optimizer state, which is serialized separately
        from model parameters. This function should only be used to continue training -
        if you wish to load a model for inference/load parts of a model into a new
        computation graph, you should use the native Pytorch functions:
        `model.load_state_dict(torch.load("/path/to/model/weights.th"))`

        If `self._serialization_dir` does not exist or does not contain any checkpointed weights,
        this function will do nothing and return 0.

        # Returns

        epoch: int
            The epoch at which to resume training, which should be one after the epoch
            in the saved training state.
        """
        model_state, training_state = self._checkpointer.restore_checkpoint()

        if not training_state:
            # No checkpoint to restore, start at 0
            return 0

        self.model.load_state_dict(model_state)
        self.optimizer.load_state_dict(training_state["optimizer"])
        if (self._learning_rate_scheduler is not None
                and "learning_rate_scheduler" in training_state):
            self._learning_rate_scheduler.load_state_dict(
                training_state["learning_rate_scheduler"])
        if self._momentum_scheduler is not None and "momentum_scheduler" in training_state:
            self._momentum_scheduler.load_state_dict(
                training_state["momentum_scheduler"])
        training_util.move_optimizer_to_cuda(self.optimizer)

        # Currently the `training_state` contains a serialized `MetricTracker`.
        if "metric_tracker" in training_state:
            self._metric_tracker.load_state_dict(
                training_state["metric_tracker"])
        # It used to be the case that we tracked `val_metric_per_epoch`.
        elif "val_metric_per_epoch" in training_state:
            self._metric_tracker.clear()
            self._metric_tracker.add_metrics(
                training_state["val_metric_per_epoch"])
        # And before that we didn't track anything.
        else:
            self._metric_tracker.clear()

        if isinstance(training_state["epoch"], int):
            epoch_to_return = training_state["epoch"] + 1
        else:
            epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1

        # For older checkpoints with batch_num_total missing, default to old behavior where
        # it is unchanged.
        batch_num_total = training_state.get("batch_num_total")
        if batch_num_total is not None:
            self._batch_num_total = batch_num_total

        return epoch_to_return

    @classmethod
    def from_partial_objects(
        cls,
        model: Model,
        serialization_dir: str,
        data_loader: DataLoader,
        validation_data_loader: DataLoader = None,
        local_rank: int = 0,
        patience: int = None,
        validation_metric: str = "-loss",
        num_epochs: int = 20,
        cuda_device: int = -1,
        grad_norm: float = None,
        grad_clipping: float = None,
        model_save_interval: float = None,
        log_batch_size_period: int = None,
        distributed: bool = None,
        world_size: int = 1,
        num_gradient_accumulation_steps: int = 1,
        opt_level: Optional[str] = None,
        no_grad: List[str] = None,
        optimizer: Lazy[Optimizer] = None,
        learning_rate_scheduler: Lazy[LearningRateScheduler] = None,
        momentum_scheduler: Lazy[MomentumScheduler] = None,
        tensorboard_writer: Lazy[TensorboardWriter] = None,
        moving_average: Lazy[MovingAverage] = None,
        checkpointer: Lazy[Checkpointer] = None,
    ) -> "Trainer":
        """
        This method exists so that we can have a documented method to construct this class using
        `FromParams`. If you are not using `FromParams` or config files, you can safely ignore this
        method.

        The reason we can't just use `__init__` with `FromParams` here is because there are
        sequential dependencies to this class's arguments.  Anything that has a `Lazy[]` type
        annotation needs something from one of the non-`Lazy` arguments.  The `Optimizer` needs to
        have the parameters from the `Model` before it's constructed, and the `Schedulers` need to
        have the `Optimizer`. Because of this, the typical way we construct things `FromParams`
        doesn't work, so we use `Lazy` to allow for constructing the objects sequentially.

        If you're not using `FromParams`, you can just construct these arguments in the right order
        yourself in your code and call the constructor directly.
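
        A minimal sketch of the "construct things in the right order yourself" path
        (the names here are illustrative placeholders, not this library's exact API):

            optimizer = SomeOptimizer(model.named_parameters(), lr=1e-3)
            scheduler = SomeLRScheduler(optimizer, num_epochs=num_epochs)
            trainer = Trainer(model, optimizer, data_loader,
                              learning_rate_scheduler=scheduler,
                              serialization_dir=serialization_dir)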
        """

        check_for_gpu(cuda_device)
        if cuda_device >= 0:
            # Moving model to GPU here so that the optimizer state gets constructed on
            # the right device.
            model = model.cuda(cuda_device)

        if no_grad:
            for name, parameter in model.named_parameters():
                if any(re.search(regex, name) for regex in no_grad):
                    parameter.requires_grad_(False)

        common_util.log_frozen_and_tunable_parameter_names(model)

        parameters = [[n, p] for n, p in model.named_parameters()
                      if p.requires_grad]
        optimizer_ = optimizer.construct(model_parameters=parameters)
        if not optimizer_:
            optimizer_ = Optimizer.default(parameters)

        try:
            batches_per_epoch = len(data_loader)
        except TypeError:
            # If the dataset is lazy, it won't have a length.
            batches_per_epoch = None

        moving_average_ = moving_average.construct(parameters=parameters)
        learning_rate_scheduler_ = learning_rate_scheduler.construct(
            optimizer=optimizer_,
            num_epochs=num_epochs,
            num_steps_per_epoch=batches_per_epoch)
        momentum_scheduler_ = momentum_scheduler.construct(
            optimizer=optimizer_)

        checkpointer_ = checkpointer.construct() or Checkpointer(
            serialization_dir)
        tensorboard_writer_ = tensorboard_writer.construct(
        ) or TensorboardWriter(serialization_dir)

        return cls(
            model,
            optimizer_,
            data_loader,
            patience=patience,
            validation_metric=validation_metric,
            validation_data_loader=validation_data_loader,
            num_epochs=num_epochs,
            serialization_dir=serialization_dir,
            cuda_device=cuda_device,
            grad_norm=grad_norm,
            grad_clipping=grad_clipping,
            learning_rate_scheduler=learning_rate_scheduler_,
            momentum_scheduler=momentum_scheduler_,
            tensorboard_writer=tensorboard_writer_,
            checkpointer=checkpointer_,
            model_save_interval=model_save_interval,
            log_batch_size_period=log_batch_size_period,
            moving_average=moving_average_,
            distributed=distributed,
            local_rank=local_rank,
            world_size=world_size,
            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
            opt_level=opt_level,
        )
    def initialize(self, training=True, force_load_plans=False):
        """
        This is a copy of nnUNetTrainerV2's initialize; the only change is that the regions are passed to the data augmentation.
        :param training:
        :param force_load_plans:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    if self.local_rank == 0:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    else:
                        # we need to wait until worker 0 has finished unpacking
                        npz_files = subfiles(self.folder_with_preprocessed_data, suffix=".npz", join=False)
                        case_ids = [i[:-4] for i in npz_files]
                        all_present = all(
                            [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                        while not all_present:
                            print("worker", self.local_rank, "is waiting for unpacking")
                            sleep(3)
                            all_present = all(
                                [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                        # There is a slight chance of an error because a dataloader may read a file
                        # that is still being written by worker 0. We ignore this for now and address
                        # it only if it becomes relevant (while worker 0 is writing, the file is
                        # technically present, so the other workers proceed and may try to read it).
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                # setting weights for deep supervision losses
                net_numpool = len(self.net_num_pool_op_kernel_sizes)

                # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
                # this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

                # we don't use the lowest-resolution output (the mask zeroes its weight);
                # then normalize the weights so that they sum to 1
                mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights
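                # Worked example (assuming net_numpool == 5): raw weights are
                # [1, 0.5, 0.25, 0.125, 0.0625]; the lowest-resolution entry is zeroed and
                # the rest renormalize to roughly [0.533, 0.267, 0.133, 0.067, 0.0].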

                # np.random.randint replaces the deprecated np.random.random_integers (high is exclusive)
                seeds_train = np.random.randint(0, 100000, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.randint(0, 100000, max(self.data_aug_params.get('num_threads') // 2, 1))
                print("seeds train", seeds_train)
                print("seeds_val", seeds_val)
                self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
                                                                    self.data_aug_params[
                                                                        'patch_size_for_spatialtransform'],
                                                                    self.data_aug_params,
                                                                    deep_supervision_scales=self.deep_supervision_scales,
                                                                    seeds_train=seeds_train,
                                                                    seeds_val=seeds_val,
                                                                    pin_memory=self.pin_memory,
                                                                    regions=self.regions)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            self._maybe_init_amp()
            self.network = DDP(self.network, self.local_rank)

        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True
def run_training(rank, args, hp, port=None):
    if args.n_gpus > 1:
        init_distributed(rank, args.n_gpus, port)
        torch.cuda.set_device(f'cuda:{rank}')

    ## NOTE: variable
    model = TransformerWav2vec2(
        hp,
        pretrain_model='facebook/wav2vec2-large-lv60',
        freeze_feature_extractor=hp.freeze_feature_extractor)
    ## TODO: change init_weight (maybe initialize all networks)
    #model.apply(init_weight)
    model.train()

    if rank == 0:
        print(model)

    model = model.to(rank)

    if args.n_gpus > 1:
        model = DDP(torch.nn.SyncBatchNorm.convert_sync_batchnorm(model),
                    device_ids=[rank])

    max_lr = hp.init_lr
    if hp.optimizer_type == 'Noam':
        ## NOTE: scheduling?
        ## NOTE: learning rate?
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=max_lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-9)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=max_lr)

    # Exactly one of hp.batch_size and hp.max_seqlen must be set (XOR).
    assert (hp.batch_size is None) != (hp.max_seqlen is None)

    if args.n_gpus > 1:
        dist.barrier()
    # Remap tensors that were saved on cuda:0 to this worker's GPU when loading checkpoints.
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}

    if hp.loaded_epoch is not None:
        start_epoch = hp.loaded_epoch
        load_dir = hp.loaded_dir
        print('epoch {} loaded'.format(hp.loaded_epoch))
        loaded_dict = load_model("{}".format(
            os.path.join(load_dir, 'network.epoch{}'.format(hp.loaded_epoch))),
                                 map_location=map_location)
        model.load_state_dict(loaded_dict)
        if hp.is_flat_start:
            step = 1
            start_epoch = 0
            print('flat_start')
        else:
            loaded_dict = torch.load("{}".format(
                os.path.join(
                    load_dir,
                    'network.optimizer.epoch{}'.format(hp.loaded_epoch))),
                                     map_location=map_location)
            optimizer.load_state_dict(loaded_dict)
            step = loaded_dict['state'][0]['step']
            #lr = get_learning_rate(step//hp.accum_grad+1, hp)
            lr = get_learning_rate_tristage(step // hp.accum_grad + 1)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            del loaded_dict
            torch.cuda.empty_cache()
    else:
        start_epoch = 0
        step = 1

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print('params = {0:.2f}M'.format(pytorch_total_params / 1000 / 1000))
    train_epoch(model,
                optimizer,
                args,
                hp,
                step=step,
                start_epoch=start_epoch,
                rank=rank)
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file",
                        type=str,
                        required=True,
                        help='File path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        required=True,
                        help='File path to use for valid file')
    parser.add_argument("--dataset_key",
                        default="paired",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="registry label of the embeddings to use")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--windowed_ra",
                        type=str2bool,
                        default=False,
                        help="whether to prevent attention beyond rpr_k")
    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number of train workers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--tgt_nctx",
                        type=int,
                        help="Max output length, defaults to args.nctx")
    parser.add_argument("--file_type", default='json', help="Suffix for data")
    parser.add_argument("--record_keys", default=['x', 'y'], nargs='+')
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--lr_scheduler",
                        type=str,
                        default='cosine',
                        help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps",
                        type=int,
                        help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate",
                        type=float,
                        help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha",
                        type=float,
                        help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim",
                        default="adamw",
                        type=str,
                        help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument(
        "--restart_tt",
        type=str,
        help="Optional param for legacy checkpoints (step|epoch)")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--reduction_d_k",
                        type=int,
                        default=64,
                        help="Dimensions of Key and Query in the single headed "
                        "reduction layers")
    parser.add_argument(
        "--reduction_type",
        type=str,
        default="2ha",
        help="If using a dual encoder, specifies the reduction type")
    parser.add_argument(
        "--unfreeze_after_step",
        default=0,
        type=int,
        help=
        "Unfreeze encoders after this step; ignored if we don't have a checkpoint")
    parser.add_argument(
        "--stacking_layers",
        type=int,
        nargs='+',
        default=[],
        help="Hidden sizes of the dense stack (ff2 from the convert paper)")
    parser.add_argument("--layer_drop",
                        type=float,
                        default=0.0,
                        help="LayerDrop to apply")
    parser.add_argument("--ff_pdrop",
                        type=float,
                        default=0.1,
                        help="Dropout in the dense stack")

    parser.add_argument("--reader_type",
                        type=str,
                        default='preprocessed',
                        choices=['ntp', 'nsp', 'preprocessed', 'tfrecord'])
    parser.add_argument(
        "--model_type",
        default="dual-encoder",
        choices=["dual-encoder", "encoder-decoder", "transformer-bow"])
    parser.add_argument("--src_begin_tok", type=str, nargs='+', default=[])
    parser.add_argument("--src_end_tok",
                        type=str,
                        nargs='+',
                        default=['<EOS>'])
    parser.add_argument("--tgt_begin_tok",
                        type=str,
                        nargs='+',
                        default=['<GO>'])
    parser.add_argument("--tgt_end_tok",
                        type=str,
                        nargs='+',
                        default=['<EOS>'])
    parser.add_argument('--lower', type=baseline.str2bool, default=False)
    parser.add_argument(
        "--loss",
        type=str,
        default='symmetric',
        choices=['triplet', 'all', 'all_mean', 'contrastive', 'symmetric'])
    parser.add_argument(
        "--learn_temp",
        type=str2bool,
        default=True,
        help=
        "If 'contrastive' or 'symmetric' loss, should we learn the temperature scaling"
    )
    parser.add_argument(
        "--init_temp",
        type=float,
        help="Initialize the temperature for 'contrastive' or 'symmetric' loss"
    )
    parser.add_argument(
        '--rpr_k',
        help=
        'Relative attention positional sizes; pass 0 if you do not want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find it)"
    )
    parser.add_argument("--save_npz",
                        type=str2bool,
                        default=False,
                        help="Whether to save an npz checkpoint")

    args = parser.parse_args()

    if args.basedir is None:
        args.basedir = '{}-{}-paired-{}-bpe-{}'.format(args.model_type,
                                                       args.reader_type,
                                                       args.dataset_key,
                                                       os.getpid())
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    if args.distributed:
        args.device, updated_local_rank = init_distributed(args.local_rank)
        args.local_rank = updated_local_rank

    if not args.tgt_nctx:
        args.tgt_nctx = args.nctx
    reader = MultiFileDatasetReader(args.nctx,
                                    args.tgt_nctx,
                                    args.src_begin_tok,
                                    args.src_end_tok,
                                    args.tgt_begin_tok,
                                    args.tgt_end_tok,
                                    args.subword_model_file,
                                    args.subword_vocab_file,
                                    args.file_type,
                                    reader_type=args.reader_type,
                                    record_keys=args.record_keys,
                                    lower=args.lower)

    vocab = reader.build_vocab()
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file,
                            vocabs,
                            distribute=False,
                            shuffle=False)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    # rpr_k: 0 (or an empty list) disables relative attention; a single value is shared
    # across layers; multiple values give a per-layer window.
    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        rpr_k = None
    elif len(args.rpr_k) == 1:
        rpr_k = args.rpr_k[0]
    else:
        rpr_k = args.rpr_k

    model = create_model(embeddings,
                         d_model=args.d_model,
                         d_ff=args.d_ff,
                         dropout=args.dropout,
                         num_heads=args.num_heads,
                         num_layers=args.num_layers,
                         model_type=args.model_type,
                         rpr_k=rpr_k,
                         d_k=args.d_k,
                         reduction_d_k=args.reduction_d_k,
                         stacking_layers=args.stacking_layers,
                         ff_pdrop=args.ff_pdrop,
                         windowed_ra=args.windowed_ra,
                         reduction_type=args.reduction_type,
                         layer_drop=args.layer_drop,
                         logger=logger)

    model.to(args.device)
    if args.model_type == 'encoder-decoder':
        run_step = run_step_s2s
    else:
        run_step = run_step_dual
        logger.info(
            f"Creating {args.loss}, init temperature: {args.init_temp}, learnable: {args.learn_temp}"
        )
    loss_function = model.create_loss(loss_type=args.loss,
                                      init_temp=args.init_temp,
                                      learn_temp=args.learn_temp)
    loss_function.to(args.device)

    logger.info("Created model and loss")

    steps_per_epoch = len(train_loader) // num_gpus
    valid_steps = len(valid_loader)
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(args.lr_scheduler,
                            args.lr,
                            steps_per_epoch,
                            args.epochs,
                            logger,
                            decay_steps=args.lr_decay_steps,
                            decay_rate=args.lr_decay_rate,
                            alpha=args.lr_alpha)
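    # Composite schedule: linear warmup for the first warmup_steps updates, after which the
    # decay schedule selected by --lr_scheduler is applied.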
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0

    if args.restart_from:

        if args.unfreeze_after_step > 0 and args.model_type == "dual-encoder":
            logger.info("Encoders will be frozen until step %d",
                        args.unfreeze_after_step)
        global_step, start_epoch = reload_from_checkpoint(
            args.model_type, args.restart_from, args.restart_tt, model,
            steps_per_epoch)
        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            args.restart_from, global_step, start_epoch + 1)

    # For the dual-encoder losses we optimize the loss module (it wraps the model and may hold
    # a learnable temperature); for encoder-decoder we optimize the model directly.
    target = model if args.model_type == 'encoder-decoder' else loss_function

    optimizer = OptimizerManager(target,
                                 global_step,
                                 optim=args.optim,
                                 lr=args.lr,
                                 lr_function=lr_sched,
                                 weight_decay=args.weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in target.parameters() if p.requires_grad)))
    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.device],
                                        output_device=args.device)
        logger.info("Model located on %d", args.local_rank)

    model_base = os.path.join(args.basedir, 'checkpoint')
    steps = global_step
    timer = Timer()

    for epoch in range(start_epoch, args.epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        timer.start()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)

            if steps > args.unfreeze_after_step and hasattr(
                    model, 'freeze') and model.freeze:
                logging.info("Unfreezing encoders at step %d", steps)
                model.freeze = False
            steps += 1

            x, y = batch
            loss = run_step(x, y, model, loss_function, args.device,
                            args.distributed)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info(avg_loss)
            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = timer.elapsed(True)
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('throughput %f steps/min', i / elapsed)
                logging.info('LR: %f', optimizer.current_lr)
                save_checkpoint(model,
                                model_base,
                                steps,
                                tick_type='step',
                                save_npz=args.save_npz)

        # How much time elapsed in minutes
        elapsed = timer.elapsed(True)
        train_avg_loss = avg_loss.avg
        # Average training loss for this epoch on this worker.
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_avg_loss
        if args.local_rank < 1:
            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            model.eval()
            valid_itr = iter(valid_loader)
            for j in range(valid_steps):
                with torch.no_grad():
                    batch = next(valid_itr)
                    x, y = batch
                    loss = run_step(x, y, model, loss_function, args.device,
                                    args.distributed)
                avg_valid_loss.update(loss.item())

            valid_avg_loss = avg_valid_loss.avg

            elapsed = timer.elapsed(True)
            metrics['valid_elapsed_min'] = elapsed

            metrics['average_valid_loss'] = valid_avg_loss
            logger.info(metrics)
            save_checkpoint(model,
                            model_base,
                            epoch,
                            tick_type='epoch',
                            save_npz=args.save_npz)
Exemple #9
0
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument(
        "--data_faiss",
        type=str,
        default="data_persona_faiss_fase1_opcion4",
        help=
        "list of the personalities selected with faiss according to the strategy selected"
    )
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer  # can't use AutoTokenizer because the checkpoint could be a Path
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(input_ids,
                                         token_type_ids=token_type_ids,
                                         mc_token_ids=mc_token_ids,
                                         mc_labels=mc_labels,
                                         lm_labels=lm_labels)
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
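        # Step/zero the optimizer only every gradient_accumulation_steps iterations; the loss
        # above was already divided by gradient_accumulation_steps so the accumulated gradients
        # match a single larger batch.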
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
                mc_token_ids=mc_token_ids,
            )
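            # Shift for next-token prediction: drop the last position from the logits and the
            # first position from the labels, then flatten both for CrossEntropyLoss:
            #   lm_logits[..., :-1, :] -> predictions for tokens 1..N
            #   lm_labels[..., 1:]     -> targets     for tokens 1..N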
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="validation",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer)),
            event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
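A minimal reload sketch, not part of the example above: since the last checkpoint is renamed to WEIGHTS_NAME and the config and tokenizer are saved into log_dir, the fine-tuned model can presumably be reloaded straight from that directory. The model class depends on args.model_checkpoint; OpenAIGPTLMHeadModel and OpenAIGPTTokenizer are used here only for illustration.

# Hedged usage sketch, assuming an OpenAI GPT checkpoint and the transformers library
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

reloaded_model = OpenAIGPTLMHeadModel.from_pretrained(log_dir)    # reads WEIGHTS_NAME + CONFIG_NAME
reloaded_tokenizer = OpenAIGPTTokenizer.from_pretrained(log_dir)  # reads the saved tokenizer files
reloaded_model.eval()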
Exemple #10
0
class Trainer:
    def __init__(self,
                 args,
                 train_loader=None,
                 val_loader=None,
                 logger=None,
                 num_answers=0,
                 train=True):
        self.args = args

        self.max_text_length = args.max_text_length

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_answers = num_answers

        self.logger = logger

        # Model
        self.model = VQAModel.from_pretrained("bert-base-uncased",
                                              args=args,
                                              num_answers=self.num_answers)

        self.verbose = True
        if self.args.distributed:
            if self.args.gpu != 0:
                self.verbose = False

        # Load Checkpoint
        self.start_epoch = None
        if args.load is not None:
            path = args.load + '.pth'
            self.load(path, verbose=self.verbose)

        elif args.load_lxmert_qa is not None:
            path = args.load_lxmert_qa + '_LXRT.pth'
            load_lxmert_qa(
                args,
                path,
                self.model,
                label2ans=self.train_loader.dataset.raw_dataset.label2ans,
                verbose=self.verbose)

        # GPU Options
        print(f'Model Launching at GPU {self.args.gpu}')
        from time import time
        start = time()
        self.model.cuda(args.gpu)

        # Optimizer
        if train:
            self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler(
            )
            self.bce_loss = nn.BCEWithLogitsLoss()

        if args.multiGPU:
            assert args.distributed
            self.model = DDP(self.model,
                             device_ids=[args.gpu],
                             find_unused_parameters=True)
        if args.gpu == 0:
            print(f'It took {time() - start:.1f}s')

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def create_optimizer_and_scheduler(self):
        if self.verbose:
            print('Building Optimizer')

        from transformers.optimization import AdamW, get_linear_schedule_with_warmup
        batch_per_epoch = len(self.train_loader)
        t_total = int(batch_per_epoch * self.args.epochs)
        warmup_ratio = self.args.warmp_ratio
        warmup_iters = int(t_total * warmup_ratio)
        if self.verbose:
            print("Batch per epoch: %d" % batch_per_epoch)
            print("Total Iters: %d" % t_total)
            print('Warmup ratio:', warmup_ratio)
            print("Warm up Iters: %d" % warmup_iters)

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.args.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]

        optim = AdamW(optimizer_grouped_parameters, self.args.lr)
        lr_scheduler = get_linear_schedule_with_warmup(optim, warmup_iters,
                                                       t_total)

        return optim, lr_scheduler

    def train(self):
        if self.verbose:
            loss_meter = LossMeter()
            quesid2ans = {}
            best_valid = 0.
            print("Valid Oracle: %0.2f" %
                  (self.oracle_score(self.val_loader) * 100))

            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.args.log_dir)

            hparam_dict = {}
            for k, v in self.args.__dict__.items():
                if type(v) in [int, float, str, bool, torch.Tensor]:
                    hparam_dict[k] = v
            metric_dict = {}

            self.writer.add_hparams(hparam_dict, metric_dict)

        if self.args.distributed:
            dist.barrier()

        for epoch in range(self.args.epochs):
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=150)

            quesid2ans = {}
            for step_i, batch in enumerate(self.train_loader):
                update = True
                if self.args.update_freq > 1:
                    if step_i == 0:
                        update = False
                    elif step_i % self.args.update_freq == 0 or step_i == len(
                            self.train_loader) - 1:
                        update = True
                    else:
                        update = False

                if self.args.distributed:
                    results = self.model.module.train_step(batch)
                else:
                    results = self.model.train_step(batch)

                vis_feats = batch['vis_feats'].cuda()
                boxes = batch['boxes'].cuda()

                input_ids = batch['word_ids'].cuda()
                target = batch['targets'].cuda()

                ques_id = batch['question_ids']

                B = len(batch['word_ids'])

                results = self.model(
                    input_ids=input_ids,
                    visual_feats=vis_feats,
                    visual_pos=boxes,
                    attention_mask=input_ids > 0,
                )
                logit = results['logit']

                assert logit.size() == target.size()
                assert logit.size() == (B, self.num_answers)

                loss = self.bce_loss(logit, target)

                loss.backward()

                if update:
                    if not self.args.no_clip_grad:
                        nn.utils.clip_grad_norm_(self.model.parameters(),
                                                 self.args.clip_grad_norm)

                    self.optim.step()
                    self.lr_scheduler.step()
                    for param in self.model.parameters():
                        param.grad = None

                try:
                    lr = self.lr_scheduler.get_last_lr()[0]
                except AttributeError:
                    lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                    desc_str += f'Loss {loss_meter.val:.4f} |'

                    score, predict = logit.max(1)
                    predict = predict.cpu().numpy()
                    target = target.cpu().numpy()

                    for qid, pred in zip(ques_id, predict):
                        pred_ans = self.train_loader.dataset.raw_dataset.label2ans[
                            pred]
                        quesid2ans[qid] = pred_ans

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

                # score, label = logit.max(1)
                # for qid, l in zip(ques_id, label.cpu().numpy()):
                #     ans = dset.label2ans[l]
                #     quesid2ans[qid.item()] = ans

            if self.verbose:
                pbar.close()
                score = self.train_loader.evaluator.evaluate(quesid2ans) * 100.
                log_str = "\nEpoch %d: Train %0.2f" % (epoch, score)

                if not self.args.dry:
                    self.writer.add_scalar(f'VQA/Train/score', score, epoch)

                # Validation
                valid_score = self.evaluate(self.val_loader) * 100.
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid)

                if not self.args.dry:
                    self.writer.add_scalar(f'VQA/Valid/score', valid_score,
                                           epoch)

                print(log_str)
                self.logger.info(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

    def predict(self, loader, dump_path=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        with torch.no_grad():
            quesid2ans = {}
            for i, batch in enumerate(
                    tqdm(loader, ncols=150, desc="Prediction")):

                vis_feats = batch['vis_feats'].cuda()
                boxes = batch['boxes'].cuda()

                input_ids = batch['word_ids'].cuda()
                ques_id = batch['question_ids']

                results = self.model(
                    input_ids=input_ids,
                    visual_feats=vis_feats,
                    visual_pos=boxes,
                    attention_mask=input_ids > 0,
                )
                logit = results['logit']
                score, predict = logit.max(1)

                predict = predict.cpu().numpy()

                for qid, pred in zip(ques_id, predict):
                    pred_ans = loader.dataset.raw_dataset.label2ans[pred]
                    quesid2ans[qid] = pred_ans

            if dump_path is not None:
                loader.evaluator.dump_result(quesid2ans, dump_path)
            return quesid2ans

    def evaluate(self, loader, dump_path=None):
        evaluator = loader.evaluator
        quesid2ans = self.predict(loader, dump_path)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(loader):
        evaluator = loader.evaluator
        quesid2ans = {}
        for i, batch in enumerate(loader):

            ques_id = batch['question_ids']
            label = batch['targets']

            _, label = label.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = loader.dataset.raw_dataset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path, loc='cpu', verbose=False):
        state_dict = load_state_dict(path, loc)
        results = self.model.load_state_dict(state_dict, strict=False)
        if verbose:
            print('Loaded from ', path)
            print(results)
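A hedged usage sketch for the Trainer above; args, the loaders and the logger are hypothetical placeholders built elsewhere (e.g. by an argument parser and dataset-specific factories), and only fields the class itself references are assumed.

# Hedged usage sketch; everything on the right-hand side is assumed to exist already
trainer = Trainer(args,
                  train_loader=train_loader,
                  val_loader=val_loader,
                  logger=logger,
                  num_answers=len(train_loader.dataset.raw_dataset.label2ans),
                  train=True)
trainer.train()  # runs the epoch loop, validates each epoch and saves BEST/LAST checkpoints
quesid2ans = trainer.predict(val_loader, dump_path=os.path.join(args.output, 'val_predict.json'))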
Exemple #11
0
    def __init__(self,
                 args,
                 train_loader=None,
                 val_loader=None,
                 logger=None,
                 num_answers=0,
                 train=True):
        self.args = args

        self.max_text_length = args.max_text_length

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_answers = num_answers

        self.logger = logger

        # Model
        self.model = VQAModel.from_pretrained("bert-base-uncased",
                                              args=args,
                                              num_answers=self.num_answers)

        self.verbose = True
        if self.args.distributed:
            if self.args.gpu != 0:
                self.verbose = False

        # Load Checkpoint
        self.start_epoch = None
        if args.load is not None:
            path = args.load + '.pth'
            self.load(path, verbose=self.verbose)

        elif args.load_lxmert_qa is not None:
            path = args.load_lxmert_qa + '_LXRT.pth'
            load_lxmert_qa(
                args,
                path,
                self.model,
                label2ans=self.train_loader.dataset.raw_dataset.label2ans,
                verbose=self.verbose)

        # GPU Options
        print(f'Model Launching at GPU {self.args.gpu}')
        from time import time
        start = time()
        self.model.cuda(args.gpu)

        # Optimizer
        if train:
            self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler(
            )
            self.bce_loss = nn.BCEWithLogitsLoss()

        if args.multiGPU:
            assert args.distributed
            self.model = DDP(self.model,
                             device_ids=[args.gpu],
                             find_unused_parameters=True)
        if args.gpu == 0:
            print(f'It took {time() - start:.1f}s')

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
Exemple #12
0
class FastSCNNOperator(BaseOperator):
    def __init__(self, cfg):
        super(FastSCNNOperator, self).__init__(cfg)
        # prepare model for train
        self.model = Fast_SCNN(input_channel=3,
                               num_classes=self.cfg.Train.num_classes).cuda(
                                   self.cfg.Distributed.gpu_id)

        if self.cfg.Distributed.dist:
            self.model = DistributedDataParallel(
                self.model,
                find_unused_parameters=True,
                device_ids=[self.cfg.Distributed.gpu_id])

        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=self.cfg.Train.learning_rate,
            weight_decay=self.cfg.Train.weight_decay)

        # self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.9, momentum=0.9)

        self.loss = torch.nn.CrossEntropyLoss(ignore_index=255)

        if self.cfg.is_training:
            self.loader = get_train_loader(name=self.cfg.Dataset.name,
                                           cfg=self.cfg)
        else:
            self.loader = get_val_loader(name=self.cfg.Dataset.name, cfg=cfg)
        self.cfg.Train.max_iter_num = self.cfg.Train.epochs * len(self.loader)

        self.main_flag = self.cfg.Distributed.gpu_id == 0

        if self.model.training:
            self.logger = Logger(self.cfg, self.main_flag)

    def adjust_lr(self, itr, max_itr):
        now_lr = self.cfg.Train.learning_rate * (
            1 - itr / (max_itr + 1))**self.cfg.Train.power
        self.optimizer.param_groups[0]['lr'] = now_lr
        # self.optimizer.param_groups[1]['lr'] = 10 * now_lr
        return now_lr

    def criterion(self, outs, labels):
        return self.loss(outs, labels)

    @staticmethod
    def save_ckp(models, step, path):
        """
        Save checkpoint of the model.
        :param models: nn.Module
        :param step: step of the checkpoint.
        :param path: save path.
        """
        torch.save(models.state_dict(),
                   os.path.join(path, 'ckp-{}.pth'.format(step)))

    def training_process(self):
        logger = Logger(self.cfg, self.main_flag)
        self.model.train()
        colormap = color_map[self.cfg.Dataset.name]

        itr = 0

        for epoch in range(self.cfg.Train.epochs):

            for idx, sample in enumerate(self.loader):

                newlr = self.adjust_lr(itr=itr,
                                       max_itr=self.cfg.Train.max_iter_num)

                xs, ys, names = sample

                ys = torch.squeeze(ys, dim=1)
                xs = xs.cuda(self.cfg.Distributed.gpu_id)
                ys = ys.cuda(self.cfg.Distributed.gpu_id)
                preds = self.model(xs)
                loss = self.criterion(preds, ys.long())
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                pred_img = seg_vis(preds, colormap)
                gt_img = seg_vis(ys, colormap)

                pred_img = torch.from_numpy(pred_img).permute(
                    2, 0, 1).unsqueeze(0).float() / 255.
                gt_img = torch.from_numpy(gt_img).permute(
                    2, 0, 1).unsqueeze(0).float() / 255.

                if self.main_flag:
                    if itr % self.cfg.Train.print_steps == self.cfg.Train.print_steps - 1:
                        log_data = {
                            'scalar': {
                                'Loss': loss.item()
                            },
                            'imgs': {
                                'Pred': [pred_img, gt_img]
                            }
                        }
                        logger.log(log_data, n_iter=itr)

                itr += 1

                if itr % self.cfg.Train.save_model == self.cfg.Train.save_model - 1:
                    ckp_dir = os.path.join(self.cfg.Train.ckp_dir,
                                           self.cfg.Train.model_name)
                    # create the checkpoint root and the per-model sub-directory if needed
                    os.makedirs(ckp_dir, exist_ok=True)
                    self.save_ckp(self.model.module, itr, ckp_dir)

    def eval_process(self):
        self.model.eval()
        self.model.module.load_state_dict(torch.load(self.cfg.Val.model_file))
        epoch = 0
        step = 0
        class_converter = np.array([
            7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31,
            32, 33
        ]).astype(np.uint8)

        self.loader.sampler.set_epoch(epoch)
        with torch.no_grad():
            for sample in self.loader:
                xs, ys, names = sample

                xs = xs.cuda(self.cfg.Distributed.gpu_id)
                outs = self.model(xs)
                pred_img = torch.max(outs, dim=1)[1].cpu()
                pred_img = pred_img.numpy().astype(np.uint8)

                print(pred_img.max())
                print(pred_img.shape)
                for i in range(pred_img.shape[0]):
                    img = class_converter[pred_img[i]]
                    # pred_img = class_converter[pred_img]
                    im = Image.fromarray(img)
                    img_dir = names[i].split(sep='/')[-2]
                    name = '_'.join(names[i].split(sep='/')[-1].split(
                        sep='_')[:-1]) + '.png'

                    if not os.path.exists(
                            os.path.join(self.cfg.Val.result_dir, img_dir)):
                        os.mkdir(os.path.join(self.cfg.Val.result_dir,
                                              img_dir))
                    im.save(
                        os.path.join(self.cfg.Val.result_dir, img_dir, name))
                if self.main_flag:
                    step += 1
                    print('Step : %d / %d' % (step, len(self.loader)))
            print('Done !!')
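For reference, the polynomial learning-rate decay applied by adjust_lr above can be written as a standalone helper; the numbers below are illustrative only.

# Minimal sketch of the poly schedule used in adjust_lr:
#   lr(itr) = base_lr * (1 - itr / (max_itr + 1)) ** power
def poly_lr(base_lr, itr, max_itr, power):
    return base_lr * (1 - itr / (max_itr + 1)) ** power

# e.g. with base_lr=0.01, power=0.9, max_itr=1000:
#   itr=0 -> 0.0100, itr=500 -> ~0.0054, itr=1000 -> ~2.0e-05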
Exemple #13
0
    def __init__(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        data_loader: torch.utils.data.DataLoader,
        patience: Optional[int] = None,
        validation_metric: str = "-loss",
        validation_data_loader: torch.utils.data.DataLoader = None,
        num_epochs: int = 20,
        serialization_dir: Optional[str] = None,
        checkpointer: Checkpointer = None,
        cuda_device: int = -1,
        grad_norm: Optional[float] = None,
        grad_clipping: Optional[float] = None,
        learning_rate_scheduler: Optional[LearningRateScheduler] = None,
        momentum_scheduler: Optional[MomentumScheduler] = None,
        tensorboard_writer: TensorboardWriter = None,
        moving_average: Optional[MovingAverage] = None,
        batch_callbacks: List[BatchCallback] = None,
        epoch_callbacks: List[EpochCallback] = None,
        distributed: bool = False,
        local_rank: int = 0,
        world_size: int = 1,
        num_gradient_accumulation_steps: int = 1,
        opt_level: Optional[str] = None,
    ) -> None:
        super().__init__(serialization_dir, cuda_device, distributed,
                         local_rank, world_size)

        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        self.model = model

        self.data_loader = data_loader
        self._validation_data_loader = validation_data_loader
        self.optimizer = optimizer

        if patience is None:  # no early stopping
            if validation_data_loader:
                logger.warning(
                    "You provided a validation dataset but patience was set to None, "
                    "meaning that early stopping is disabled")
        elif (not isinstance(patience, int)) or patience <= 0:
            raise ConfigurationError(
                '{} is an invalid value for "patience": it must be a positive integer '
                "or None (if you want to disable early stopping)".format(
                    patience))

        # For tracking is_best_so_far and should_stop_early
        self._metric_tracker = MetricTracker(patience, validation_metric)
        # Get rid of + or -
        self._validation_metric = validation_metric[1:]

        self._num_epochs = num_epochs

        if checkpointer is not None:
            self._checkpointer = checkpointer
        else:
            self._checkpointer = Checkpointer(serialization_dir)

        self._grad_norm = grad_norm
        self._grad_clipping = grad_clipping

        self._learning_rate_scheduler = learning_rate_scheduler
        self._momentum_scheduler = momentum_scheduler
        self._moving_average = moving_average
        self._batch_callbacks = batch_callbacks or []
        self._epoch_callbacks = epoch_callbacks or []

        # We keep the total batch number as an instance variable because it
        # is used inside a closure for the hook which logs activations in
        # `_enable_activation_logging`.
        self._batch_num_total = 0

        self._tensorboard = tensorboard_writer or TensorboardWriter(
            serialization_dir)
        self._tensorboard.get_batch_num_total = lambda: self._batch_num_total
        self._tensorboard.enable_activation_logging(self.model)

        self._last_log = 0.0  # time of last logging

        self._num_gradient_accumulation_steps = num_gradient_accumulation_steps

        # Enable automatic mixed precision training with NVIDIA Apex.
        self._opt_level = opt_level
        if self._opt_level is not None:
            if amp is None:
                raise ConfigurationError((
                    "Apex not installed but opt_level was provided. Please install NVIDIA's Apex to enable"
                    " automatic mixed precision (AMP) training. See: https://github.com/NVIDIA/apex."
                ))

            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self._opt_level)

        # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
        # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
        # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
        #
        # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the
        # normal case, reference to `Model` is retained. This reference is only used in
        # these places: `model.__call__`, `model.train` and `model.eval`.
        if self._distributed:
            self._pytorch_model = DistributedDataParallel(
                self.model,
                device_ids=[self.cuda_device],
                find_unused_parameters=True)
        else:
            self._pytorch_model = self.model
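A brief, hedged illustration of the split the comment above describes (the helper name below is hypothetical): forward passes and train/eval mode switches go through self._pytorch_model, which may be the DDP wrapper, while AllenNLP Model-specific calls stay on self.model.

# Hedged illustration of the self.model / self._pytorch_model split (not part of the trainer)
def train_batch_sketch(trainer, batch):
    trainer._pytorch_model.train()                 # mode switching on the (possibly DDP-wrapped) module
    output_dict = trainer._pytorch_model(**batch)  # the forward pass also goes through the wrapper
    loss = output_dict["loss"]
    penalty = trainer.model.get_regularization_penalty()  # Model-specific API stays on the plain Model
    if penalty is not None:
        loss = loss + penalty
    return loss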
Exemple #14
0
def train_ai2thor(model, args, rank=0, b=None):

    seed = args.seed + 10000 * rank
    torch.manual_seed(seed)
    np.random.seed(seed)

    # torch.cuda.set_device(rank)
    # device = torch.device(f'cuda:{rank}')
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    # if torch.cuda.is_available():
    #     os.environ['DISPLAY'] = f':{rank}'

    model = model.to(device)
    model.share_memory()

    # Experience buffer
    storage = PPOBuffer(model.obs_shape, args.steps, args.num_workers, args.state_size, args.gamma, device=device)
    storage.share_memory()
        
    #torch.multiprocessing.set_start_method('spawn')
    # start multiple processes
    ready_to_works = [Event() for _ in range(args.num_workers)]
    exit_flag = Value('i', 0)
    queue = SimpleQueue()

    processes = []
    task_config_file = "config_files/NavTaskTrain.json"
    # start workers
    for worker_id in range(args.num_workers):
        p = Process(target=worker, args=(worker_id, model, storage, ready_to_works[worker_id], queue, exit_flag, task_config_file))
        p.start()
        processes.append(p)

    # start trainer
    train_params = {"epochs": args.epochs,
                    "steps": args.steps,
                    "world_size": args.world_size,
                    "num_workers": args.num_workers
                    }
    ppo_params = {"clip_param": args.clip_param,
                  "train_iters": args.train_iters,
                  "mini_batch_size": args.mini_batch_size,
                  "value_loss_coef": args.value_loss_coef,
                  "entropy_coef": args.entropy_coef,
                  "rnn_steps": args.rnn_steps,
                  "lr": args.lr,
                  "max_kl": args.max_kl
                  }
    
    # DistributedDataParallel wrapping is currently disabled; set use_ddp = True to enable it
    # when running with more than one process.
    use_ddp = False
    distributed = False
    if args.world_size > 1 and use_ddp:
        distributed = True
        # Initialize the process group with the NCCL distributed backend
        dist_backend = 'nccl'
        # URL used to set up distributed training
        dist_url = "tcp://127.0.0.1:23456"
        print("Initialize Process Group... pid:", os.getpid())
        dist.init_process_group(backend=dist_backend, init_method=dist_url, rank=rank, world_size=args.world_size)
        # Make model DistributedDataParallel
        model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    elif args.world_size <= 1:
        print('Distributed training disabled (world_size <= 1)')
        
    learner(model, storage, train_params, ppo_params, ready_to_works, queue, exit_flag, rank, distributed, b)

    for p in processes:
        print("process ", p.pid, " joined")
        p.join()
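The worker/learner coordination above relies on shared multiprocessing primitives; a small, self-contained sketch of that hand-shake is given below with simplified bodies, since the real worker and learner are defined elsewhere.

# Hedged sketch of the Event/Value/SimpleQueue hand-shake (simplified, illustrative bodies)
from multiprocessing import Event, SimpleQueue, Value

def toy_worker(worker_id, ready, queue, exit_flag):
    while exit_flag.value == 0:
        ready.wait()                  # block until the learner requests a rollout
        ready.clear()
        queue.put((worker_id, 'rollout done'))

def toy_learner(ready_to_works, queue, exit_flag, epochs=2):
    for _ in range(epochs):
        for ready in ready_to_works:  # wake every worker once per epoch
            ready.set()
        for _ in ready_to_works:
            print(queue.get())        # collect one result per worker
    exit_flag.value = 1               # tell workers to leave their loop
    for ready in ready_to_works:
        ready.set()                   # unblock any worker still waiting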
Exemple #15
0
    def check_parity(amp: bool, manual_reduction: bool):

        # The API should be the exact same in between the sharded and non-sharded variants, generic closure
        def closure(model,
                    scaler,
                    input_tensor,
                    should_accumulate,
                    _manual_reduction=False):
            accumulate_steps = 3 if should_accumulate else 1

            model.zero_grad()

            def step():
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        loss = model(input_tensor).abs().sum()
                        scaler.scale(loss).backward()
                else:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()

            with model.no_sync() if should_accumulate else suppress():
                for _ in range(accumulate_steps - 1):
                    step()

            if not _manual_reduction:
                step()
            else:
                with model.no_sync():
                    step()

                model.reduce()

        # Any model works. Add one different buffer per rank
        model = _get_mlp()
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Make sure that the model starts with non-trainable, so that we check for the buckets to be
        # properly reassigned when/if this changes
        next(model.parameters()).requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(),
                                optim=torch.optim.SGD,
                                lr=1e-4,
                                momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model,
            sharded_optimizer=sharded_optimizer,
            broadcast_buffers=True,
            reduce_buffer_size=reduce_buffer_size,
            reduce_fp16=fp16_reduction,
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                        lr=1e-4,
                                        momentum=0.99)
        ddp_model = DDP(ddp_model_single,
                        device_ids=[rank],
                        broadcast_buffers=True,
                        find_unused_parameters=True)

        if fp16_reduction:
            from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

            ddp_model.register_comm_hook(
                state=None, hook=fp16_compress_hook)  # type: ignore

        ddp_scaler = TorchGradScaler() if amp else None
        sharded_ddp_scaler = ShardedGradScaler() if amp else None

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop, check that we get the exact same results as DDP
        for i in range(NUMBER_BATCHS):
            input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor,
                               grad_accumulation)

            def closure_sharded(input_tensor=input_tensor):
                return closure(
                    sharded_ddp_model,
                    sharded_ddp_scaler,
                    input_tensor,
                    grad_accumulation,
                    _manual_reduction=manual_reduction,
                )

            # Step/scale both
            if ddp_scaler is not None:
                _ = closure_ddp(input_tensor)
                ddp_scaler.step(ddp_optimizer)
                ddp_scaler.update()
            else:
                ddp_optimizer.step(closure=closure_ddp)

            if sharded_ddp_scaler is not None:
                _ = closure_sharded(input_tensor)
                sharded_ddp_scaler.step(sharded_optimizer)
                sharded_ddp_scaler.update()
            else:
                sharded_optimizer.step(closure=closure_sharded)

            check_same_model_params(sharded_ddp_model, ddp_model,
                                    f"Rank: {rank} - Step {i} broke")

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                next(sharded_ddp_model.parameters()).requires_grad = not next(
                    sharded_ddp_model.parameters()).requires_grad
                next(ddp_model.parameters()).requires_grad = not next(
                    ddp_model.parameters()).requires_grad
                check_same_model_params(
                    sharded_ddp_model, ddp_model,
                    f"Rank: {rank} - Trainability refresh {i} broke")
Exemple #16
0
def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # test_display_triplet_distance = False
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    opts = vars(args)
    keys = list(opts.keys())
    keys.sort()

    options = []
    for k in keys:
        options.append("\'%s\': \'%s\'" % (str(k), str(opts[k])))

    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # instantiate model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]

    context = args.context.split(',')
    context = [int(x) for x in context]
    if args.padding == '':
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = args.padding.split(',')
        padding = [int(x) for x in padding]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = args.stride.split(',')
    stride = [int(x) for x in stride]

    channels = args.channels.split(',')
    channels = [int(x) for x in channels]

    dilation = args.dilation.split(',')
    dilation = [int(x) for x in dilation]
    model_kwargs = {
        'input_dim': args.input_dim,
        'feat_dim': args.feat_dim,
        'kernel_size': kernel_size,
        'context': context,
        'filter_fix': args.filter_fix,
        'dilation': dilation,
        'first_2d': args.first_2d,
        'mask': args.mask_layer,
        'mask_len': args.mask_len,
        'block_type': args.block_type,
        'filter': args.filter,
        'exp': args.exp,
        'inst_norm': args.inst_norm,
        'input_norm': args.input_norm,
        'stride': stride,
        'fast': args.fast,
        'avg_size': args.avg_size,
        'time_dim': args.time_dim,
        'padding': padding,
        'encoder_type': args.encoder_type,
        'vad': args.vad,
        'transform': args.transform,
        'embedding_size': args.embedding_size,
        'ince': args.inception,
        'resnet_size': args.resnet_size,
        'num_classes': train_dir.num_spks,
        'num_classes_b': train_dir.num_doms,
        'channels': channels,
        'alpha': args.alpha,
        'dropout_p': args.dropout_p,
        'loss_type': args.loss_type,
        'm': args.m,
        'margin': args.margin,
        's': args.s,
        'iteraion': 0,
        'all_iteraion': args.all_iteraion
    }

    print('Model options: {}'.format(model_kwargs))
    dist_type = 'cos' if args.cos_sim else 'l2'
    print('Testing with %s distance, ' % dist_type)

    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path,
                                                   start_epoch)
        torch.save(model, check_path)

    iteration = 0  # if args.resume else 0
    if args.finetune and args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            checkpoint_state_dict = checkpoint['state_dict']
            if isinstance(checkpoint_state_dict, tuple):
                checkpoint_state_dict = checkpoint_state_dict[0]
            filtered = {
                k: v
                for k, v in checkpoint_state_dict.items()
                if 'num_batches_tracked' not in k
            }
            if list(filtered.keys())[0].startswith('module'):
                new_state_dict = OrderedDict()
                for k, v in filtered.items():
                    name = k[7:]  # strip the leading 'module.' prefix (first 7 characters)
                    new_state_dict[name] = v  # store the value under the stripped key

                model.load_state_dict(new_state_dict)
            else:
                model_dict = model.state_dict()
                model_dict.update(filtered)
                model.load_state_dict(model_dict)
            # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks,
                                  feat_dim=args.embedding_size)
    elif args.loss_type == 'variance':
        xe_criterion = VarianceLoss(num_classes=train_dir.num_spks,
                                    feat_dim=args.embedding_size)
    elif args.loss_type == 'gaussian':
        xe_criterion = GaussianLoss(num_classes=train_dir.num_spks,
                                    feat_dim=args.embedding_size)
    elif args.loss_type == 'coscenter':
        xe_criterion = CenterCosLoss(num_classes=train_dir.num_spks,
                                     feat_dim=args.embedding_size)
    elif args.loss_type == 'mulcenter':
        xe_criterion = MultiCenterLoss(num_classes=train_dir.num_spks,
                                       feat_dim=args.embedding_size,
                                       num_center=args.num_center)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    elif args.loss_type == 'arcsoft':
        ce_criterion = None
        xe_criterion = ArcSoftmaxLoss(margin=args.margin,
                                      s=args.s,
                                      iteraion=iteration,
                                      all_iteraion=args.all_iteraion)
    elif args.loss_type == 'wasse':
        xe_criterion = Wasserstein_Loss(source_cls=args.source_cls)
    elif args.loss_type == 'ring':
        xe_criterion = RingLoss(ring=args.ring)
        args.alpha = 0.0

    model_para = [{'params': model.parameters()}]
    if args.loss_type in [
            'center', 'variance', 'mulcenter', 'gaussian', 'coscenter', 'ring'
    ]:
        assert args.lr_ratio > 0
        model_para.append({
            'params': xe_criterion.parameters(),
            'lr': args.lr * args.lr_ratio
        })

    if args.finetune or args.second_wd > 0:
        # if args.loss_type in ['asoft', 'amsoft']:
        classifier_params = list(map(id, model.classifier.parameters()))
        rest_params = filter(lambda p: id(p) not in classifier_params,
                             model.parameters())
        init_lr = args.lr * args.lr_ratio if args.lr_ratio > 0 else args.lr
        init_wd = args.second_wd if args.second_wd > 0 else args.weight_decay
        print('Set the lr and weight_decay of classifier to %f and %f' %
              (init_lr, init_wd))
        model_para = [{
            'params': rest_params
        }, {
            'params': model.classifier.parameters(),
            'lr': init_lr,
            'weight_decay': init_wd
        }]

    if args.filter in ['fDLR', 'fBLayer', 'fLLayer', 'fBPLayer']:
        filter_params = list(map(id, model.filter_layer.parameters()))
        rest_params = filter(lambda p: id(p) not in filter_params,
                             model_para[0]['params'])
        init_wd = args.filter_wd if args.filter_wd > 0 else args.weight_decay
        init_lr = args.lr * args.lr_ratio if args.lr_ratio > 0 else args.lr
        print('Set the lr and weight_decay of filter layer to %f and %f' %
              (init_lr, init_wd))
        model_para[0]['params'] = rest_params
        model_para.append({
            'params': model.filter_layer.parameters(),
            'lr': init_lr,
            'weight_decay': init_wd
        })

    optimizer = create_optimizer(model_para, args.optimizer, **opt_kwargs)

    if not args.finetune and args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            checkpoint_state_dict = checkpoint['state_dict']
            if isinstance(checkpoint_state_dict, tuple):
                checkpoint_state_dict = checkpoint_state_dict[0]
            filtered = {
                k: v
                for k, v in checkpoint_state_dict.items()
                if 'num_batches_tracked' not in k
            }

            # filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            if list(filtered.keys())[0].startswith('module'):
                new_state_dict = OrderedDict()
                for k, v in filtered.items():
                    name = k[7:]  # strip the leading 'module.' prefix (first 7 characters)
                    new_state_dict[name] = v  # store the value under the stripped key

                model.load_state_dict(new_state_dict)
            else:
                model_dict = model.state_dict()
                model_dict.update(filtered)
                model.load_state_dict(model_dict)
            # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # Save model config txt
    with open(
            osp.join(
                args.check_path,
                'model.%s.conf' % time.strftime("%Y.%m.%d", time.localtime())),
            'w') as f:
        f.write('model: ' + str(model) + '\n')
        f.write('CrossEntropy: ' + str(ce_criterion) + '\n')
        f.write('Other Loss: ' + str(xe_criterion) + '\n')
        f.write('Optimizer: ' + str(optimizer) + '\n')

    milestones = args.milestones.split(',')
    milestones = [int(x) for x in milestones]
    milestones.sort()
    if args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == 'rop':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   patience=args.patience,
                                                   min_lr=1e-5)
    elif args.scheduler == 'cyclic':
        cycle_momentum = False if args.optimizer == 'adam' else True
        scheduler = lr_scheduler.CyclicLR(optimizer,
                                          base_lr=1e-8,
                                          max_lr=args.lr,
                                          step_size_up=13000,
                                          cycle_momentum=cycle_momentum)
    else:
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=milestones,
                                             gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    # start = 0
    end = start + args.epochs

    if len(args.random_chunk
           ) == 2 and args.random_chunk[0] < args.random_chunk[1]:
        min_chunk_size = int(args.random_chunk[0])
        max_chunk_size = int(args.random_chunk[1])
        pad_dim = 2 if args.feat_format == 'kaldi' else 3

        train_loader = torch.utils.data.DataLoader(
            train_dir,
            batch_size=args.batch_size,
            collate_fn=PadCollate(
                dim=pad_dim,
                num_batch=int(np.ceil(len(train_dir) / args.batch_size)),
                min_chunk_size=min_chunk_size,
                max_chunk_size=max_chunk_size),
            shuffle=args.shuffle,
            **kwargs)
        valid_loader = torch.utils.data.DataLoader(
            valid_dir,
            batch_size=int(args.batch_size / 2),
            collate_fn=PadCollate(dim=pad_dim,
                                  fix_len=True,
                                  min_chunk_size=args.chunk_size,
                                  max_chunk_size=args.chunk_size + 1),
            shuffle=False,
            **kwargs)
    else:
        train_loader = torch.utils.data.DataLoader(train_dir,
                                                   batch_size=args.batch_size,
                                                   shuffle=args.shuffle,
                                                   **kwargs)
        valid_loader = torch.utils.data.DataLoader(valid_dir,
                                                   batch_size=int(
                                                       args.batch_size / 2),
                                                   shuffle=False,
                                                   **kwargs)
    train_extract_loader = torch.utils.data.DataLoader(train_extract_dir,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       **extract_kwargs)

    if args.cuda:
        if len(args.gpu_id) > 1:
            print("Continue with gpu: %s ..." % str(args.gpu_id))
            torch.distributed.init_process_group(
                backend="nccl",
                # init_method='tcp://localhost:23456',
                init_method=
                'file:///home/ssd2020/yangwenhao/lstm_speaker_verification/data/sharedfile',
                rank=0,
                world_size=1)
            # if args.gain
            # model = DistributedDataParallel(model.cuda(), find_unused_parameters=True)
            model = DistributedDataParallel(model.cuda())

        else:
            model = model.cuda()

        for i in range(len(ce)):
            if ce[i] != None:
                ce[i] = ce[i].cuda()
        try:
            print('Dropout is {}.'.format(model.dropout_p))
        except AttributeError:
            # the attribute is absent when the model is wrapped in DistributedDataParallel
            pass

    xvector_dir = args.check_path
    xvector_dir = xvector_dir.replace('checkpoint', 'xvector')
    start_time = time.time()

    try:
        for epoch in range(start, end):
            # pdb.set_trace()
            lr_string = '\n\33[1;34m Current \'{}\' learning rate is '.format(
                args.optimizer)
            for param_group in optimizer.param_groups:
                lr_string += '{:.10f} '.format(param_group['lr'])
            print('%s \33[0m' % lr_string)

            train(train_loader, model, ce, optimizer, epoch, scheduler)
            valid_loss = valid_class(valid_loader, model, ce, epoch)

            if (epoch == 1 or epoch !=
                (end - 2)) and (epoch % 4 == 1 or epoch in milestones
                                or epoch == (end - 1)):
                model.eval()
                check_path = '{}/checkpoint_{}.pth'.format(
                    args.check_path, epoch)
                model_state_dict = model.module.state_dict() \
                    if isinstance(model, DistributedDataParallel) else model.state_dict()
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model_state_dict,
                        'criterion': ce
                    }, check_path)

                valid_test(train_extract_loader, model, epoch, xvector_dir)
                test(model, epoch, writer, xvector_dir)
                if epoch != (end - 1):
                    try:
                        shutil.rmtree("%s/train/epoch_%s" %
                                      (xvector_dir, epoch))
                        shutil.rmtree("%s/test/epoch_%s" %
                                      (xvector_dir, epoch))
                    except Exception as e:
                        print('rm dir xvectors error:', e)

            if args.scheduler == 'rop':
                scheduler.step(valid_loss)
            elif args.scheduler == 'cyclic':
                continue
            else:
                scheduler.step()

    except KeyboardInterrupt:
        end = epoch

    writer.close()
    stop_time = time.time()
    t = float(stop_time - start_time)
    print("Running %.4f minutes for each epoch.\n" % (t / 60 /
                                                      (max(end - start, 1))))
    exit(0)
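Both resume branches above strip a leading 'module.' prefix by hand when the checkpoint was saved from a DistributedDataParallel model; the same idea as a small hedged helper (illustrative only, not used by the script):

# Hedged helper: strip the DistributedDataParallel 'module.' prefix from a state dict
from collections import OrderedDict

def strip_module_prefix(state_dict):
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[k[7:] if k.startswith('module.') else k] = v
    return cleaned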
Exemple #17
0
class nnUNetTrainerV2BraTSRegions_DDP(nnUNetTrainerV2_DDP):
    def __init__(self, plans_file, fold, local_rank, output_folder=None, dataset_directory=None, batch_dice=True,
                 stage=None,
                 unpack_data=True, deterministic=True, distribute_batch_size=False, fp16=False):
        super().__init__(plans_file, fold, local_rank, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                         deterministic, distribute_batch_size, fp16)
        self.regions = get_brats_regions()
        self.regions_class_order = (1, 2, 3)
        self.loss = None
        self.ce_loss = nn.BCEWithLogitsLoss()

    def process_plans(self, plans):
        super().process_plans(plans)
        """
        The network has as many outputs as we have regions
        """
        self.num_classes = len(self.regions)

    def initialize_network(self):
        """inference_apply_nonlin to sigmoid"""
        super().initialize_network()
        self.network.inference_apply_nonlin = nn.Sigmoid()

    def initialize(self, training=True, force_load_plans=False):
        """
        this is a copy of nnUNetTrainerV2's initialize. We only add the regions to the data augmentation
        :param training:
        :param force_load_plans:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    if self.local_rank == 0:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    else:
                        # we need to wait until worker 0 has finished unpacking
                        npz_files = subfiles(self.folder_with_preprocessed_data, suffix=".npz", join=False)
                        case_ids = [i[:-4] for i in npz_files]
                        all_present = all(
                            [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                        while not all_present:
                            print("worker", self.local_rank, "is waiting for unpacking")
                            sleep(3)
                            all_present = all(
                                [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                        # there is some slight chance that an error may arise because a dataloader reads a file
                        # that is still being written by worker 0. We ignore this for now and address it only if
                        # it becomes relevant
                        # (this can occur because while worker 0 is writing, the file is technically present, so
                        # the other workers will proceed and eventually try to read it)
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                # setting weights for deep supervision losses
                net_numpool = len(self.net_num_pool_op_kernel_sizes)

                # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
                # this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

                # we don't use the lowest-resolution output. Normalize weights so that they sum to 1
                mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights
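                # Illustrative example: with net_numpool = 5 the raw weights are
                # [1, 1/2, 1/4, 1/8, 1/16]; the lowest-resolution entry is masked to 0 and the
                # rest renormalized, giving roughly [0.533, 0.267, 0.133, 0.067, 0.0].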

                seeds_train = np.random.random_integers(0, 99999, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.random_integers(0, 99999, max(self.data_aug_params.get('num_threads') // 2, 1))
                print("seeds train", seeds_train)
                print("seeds_val", seeds_val)
                self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
                                                                    self.data_aug_params[
                                                                        'patch_size_for_spatialtransform'],
                                                                    self.data_aug_params,
                                                                    deep_supervision_scales=self.deep_supervision_scales,
                                                                    seeds_train=seeds_train,
                                                                    seeds_val=seeds_val,
                                                                    pin_memory=self.pin_memory,
                                                                    regions=self.regions)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            self._maybe_init_amp()
            self.network = DDP(self.network, device_ids=[self.local_rank])

        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
                 step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None):
        super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
                               save_softmax=save_softmax, use_gaussian=use_gaussian,
                               overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
                               all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs)
        # run brats specific validation
        output_folder = join(self.output_folder, validation_folder_name)
        evaluate_regions(output_folder, self.gt_niftis_folder, self.regions)

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        raise NotImplementedError("this class has not been changed to work with pytorch amp yet!")
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data, gpu_id=None)
            target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None

        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # the network outputs raw logits; for region-based BraTS training we apply a sigmoid
            # (one channel per region) rather than a softmax before computing the dice terms
            output_softmax = torch.sigmoid(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)
            else:
                pass

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                output = output[0]
                target = target[0]
                out_sigmoid = torch.sigmoid(output)
                out_sigmoid = (out_sigmoid > 0.5).float()

                if self.threeD:
                    axes = (2, 3, 4)
                else:
                    axes = (2, 3)

                tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)

                tp_hard = awesome_allgather_function.apply(tp)
                fp_hard = awesome_allgather_function.apply(fp)
                fn_hard = awesome_allgather_function.apply(fn)
                # print_if_rank0("after allgather", tp_hard.shape)

                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(tp_hard.detach().cpu().numpy().sum(0),
                                           fp_hard.detach().cpu().numpy().sum(0),
                                           fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None or not torch.cuda.is_available():
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()

    def run_online_evaluation(self, tp, fp, fn):
        self.online_eval_foreground_dc.append(list((2 * tp) / (2 * tp + fp + fn + 1e-8)))
        self.online_eval_tp.append(list(tp))
        self.online_eval_fp.append(list(fp))
        self.online_eval_fn.append(list(fn))
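run_online_evaluation above stores per-class tp/fp/fn counts that have already been all-gathered across GPUs. A minimal sketch of how such accumulated counts can be reduced to a global foreground Dice at the end of an epoch (the reduction shown here is an illustrative assumption, not the repository's exact code):

import numpy as np

def global_foreground_dice(online_eval_tp, online_eval_fp, online_eval_fn, eps=1e-8):
    # pool the per-batch counts over the whole epoch, per class
    tp = np.sum(online_eval_tp, axis=0)
    fp = np.sum(online_eval_fp, axis=0)
    fn = np.sum(online_eval_fn, axis=0)
    # a "global" Dice from pooled counts is more robust than averaging per-batch Dice
    # when some batches contain no foreground voxels
    return 2 * tp / (2 * tp + fp + fn + eps)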
Exemple #18
0
def train():
    args = get_args()
    '''Setup'''
    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path, exist_ok=True)
    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning("Running process %d", args.local_rank)
    logger.info("Arguments: %s", pformat(args))
    '''Initialize distributed training if needed'''
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = VideoGPT2LMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)
    '''
    Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    '''
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        model = model.module

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders_new(args, tokenizer)
    '''Training function and trainer'''
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, token_type_ids, labels, input_mask, i3d, video_mask, reply_mask = batch
        input_embs = model.transformer.wte(input_ids)
        video_embs = model.video_ff(i3d)
        input_embs = torch.cat([video_embs, input_embs], dim=1)
        token_type_ids = torch.cat([
            torch.ones((i3d.size(0), i3d.size(1))).long().cuda() *
            tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]), token_type_ids
        ],
                                   dim=1)
        video_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[video_mask, input_mask],
                           mode="video")[0]
        reply_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[reply_mask, input_mask],
                           mode="reply")[0]
        loss = (video_loss + reply_loss) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    '''Evaluation function and evaluator (evaluator output is the input of the metrics)'''

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, token_type_ids, lm_labels, input_mask, i3d, video_mask, reply_mask = batch
            input_embs = model.transformer.wte(input_ids)
            video_embs = model.video_ff(i3d)
            input_embs = torch.cat([video_embs, input_embs], dim=1)
            token_type_ids = torch.cat([
                torch.ones((i3d.size(0), i3d.size(1))).long().cuda() *
                tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]),
                token_type_ids
            ],
                                       dim=1)
            model_outputs = model(input_embs,
                                  token_type_ids=token_type_ids,
                                  attention_mask=[reply_mask, input_mask])[0]

            lm_logits = model_outputs  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    '''Engines'''
    trainer = Engine(update)
    evaluator = Engine(inference)
    '''
    Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    '''
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0], x[1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)
    '''
    On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    '''
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir="./tb_logs")
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(args.log_path,
                                             'checkpoint',
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})
        # "getattr" take care of distributed encapsulation

        torch.save(args, os.path.join(args.log_path, 'model_training_args.bin'))
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(args.log_path, CONFIG_NAME))
        tokenizer.save_vocabulary(args.log_path)
    '''Run the training'''
    trainer.run(train_loader, max_epochs=args.n_epochs)
    '''
    On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    '''
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        # TODO: PR in ignite to have better access to saved file paths (cleaner)
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(args.log_path, WEIGHTS_NAME))
        tb_logger.close()
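The update() function above divides the summed losses by args.gradient_accumulation_steps and only steps the optimizer every that many engine iterations. The same accumulation pattern in a plain loop, as a self-contained sketch with stand-in model and data (not the repository's code):

import torch

model = torch.nn.Linear(16, 1)                      # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=6.25e-5)
accumulation_steps = 4

for iteration in range(1, 101):                     # stand-in for iterating a DataLoader
    x, y = torch.randn(8, 16), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / accumulation_steps  # scale so grads average
    loss.backward()                                 # gradients accumulate in .grad
    if iteration % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()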
Exemple #19
0
def train(local_rank, args):
    torch.backends.cudnn.benchmark = True
    import os
    # torch.multiprocessing.set_sharing_strategy('file_system')
    # too many barriers / one node data parallel and multiple node DDP
    os.environ['MASTER_ADDR'] = args["master_addr"]
    os.environ['MASTER_PORT'] = args["master_port"]
    os.environ["NCCL_DEBUG"] = "WARN"
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
    # gpu_device = 0
    gpu_device = local_rank
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if args["wandb_dryrun"]:
        os.environ["WANDB_MODE"] = "dryrun"
        os.environ["WANDB_SILENT"] = "true"
    os.environ['TOKENIZERS_PARALLELISM'] = "true"
    torch.backends.cudnn.benchmark = True
    rank = args["nr"] if args["cpu"] else (args["nr"] * args["gpus_per_node"] + local_rank)
    nr = args["nr"]
    if args["cpu"]:
        assert local_rank == 0
        device = torch.device("cpu")
        args["dist_backend"] = "gloo"
        # init_method = "tcp://%s:%s" % ("127.0.0.1", "9999")
    else:
        device = torch.device(f'cuda:{gpu_device}')  # Unique only on individual node.
        torch.cuda.set_device(device)
    if args["init_method"] == "tcp":
        if args["nr"] == 0:
            args["master_addr"] = "0.0.0.0"
        init_method="tcp://%s:%s" % (args["master_addr"], args["master_port"])
    elif args["init_method"] == "file":
        init_method = 'file://%s/%s' % (args["master_addr"], args["master_port"])
    else:
        raise ValueError

    rnd = torch.tensor(0.0, device="cpu")
    if args["world_size"] > 1:
        dist.init_process_group(args["dist_backend"], rank=rank, world_size=args["world_size"], init_method=init_method)
        rnd = torch.tensor(int(time.time())).to(device)
        dist.broadcast(rnd, 0)
    barrier = get_barrier(args["world_size"] > 1)
    format = "%Y-%m-%d %H-%M %Z"
    # + timedelta(hours=5, minutes=30)
    time_string = (datetime.fromtimestamp(time.mktime(time.gmtime(rnd.cpu().item())))).astimezone(timezone('Asia/Kolkata')).strftime(format)
    ds_name = list(filter(lambda x: len(x.strip()) > 0, args["dataset"].split("/")))[-1].replace("train_fastformer_resampled_", "")
    set_seeds(args["seed"])
    batch_size = 8

    optimizer_config.lr = args["lr"]
    optimizer_config.weight_decay = args["weight_decay"]
    optimizer_config.gradient_clipping = args["gradient_clipping"]
    optimizer_config.beta_1 = args["beta_1"]
    optimizer_config.beta_2 = args["beta_2"]

    eps = 1e-4
    if args["no_autocast"]:
        optimizer_config.eps = 1e-7
        eps = 1e-7

    reinit = "pretrained_model" not in args or args["pretrained_model"] is None or args["pretrained_model"] == ""
    backbone, tokenizer = get_mtt_backbone(args["model_config"], args["cls_tokens"], args["enable_layer_normalizers"], args["sampling_alpha"], reinit,
                                           args["enable_layer_normalizers"], args["enable_layer_normalizers_statistics"], dropout_prob=0.01)
    teacher_backbone, _ = get_mtt_backbone(args["model_config"], args["cls_tokens"], args["enable_layer_normalizers"], None, reinit, args["enable_layer_normalizers"],
                                           args["enable_layer_normalizers_statistics"], dropout_prob=0.0)

    batch_size = args["batch_size"] if "batch_size" in args and isinstance(args["batch_size"], int) else batch_size
    generator_w = args["generator_w"] if "generator_w" in args else 0.0
    discriminator_w = args["discriminator_w"] if "discriminator_w" in args else 0.0
    dino_w = args["dino_w"] if "dino_w" in args else 0.0
    sentence_order_prediction_w = args["sentence_order_prediction_w"] if "sentence_order_prediction_w" in args else 0.0
    attention_penalty_w = args["attention_penalty_w"] if "attention_penalty_w" in args else 0.0

    student = MTTModel(backbone, tokenizer, args["cls_tokens"],
                       generator_w=generator_w, discriminator_w=discriminator_w,
                       dino_w=dino_w, sentence_order_prediction_w=sentence_order_prediction_w, attention_penalty_w=attention_penalty_w,
                       lm_layers=args["lm_layers"], electra_layers=args["electra_layers"],
                       lm_layers_total=args["lm_layers_total"], electra_layers_total=args["electra_layers_total"],
                       drop_unused_layers=args["drop_unused_layers"], approximate_unused_layers=args["consecutive_layers"],
                       exclude_layers=args["exclude_layers"], keep_last_layer=args["keep_last_layer"],
                       lm_temperature=args["lm_temperature"])
    teacher = MTTModel(teacher_backbone, tokenizer, args["cls_tokens"],
                       generator_w=generator_w, discriminator_w=discriminator_w,
                       dino_w=1.0, sentence_order_prediction_w=sentence_order_prediction_w, attention_penalty_w=0.0,
                       lm_layers=None, electra_layers=None,
                       lm_layers_total=args["lm_layers_total"], electra_layers_total=args["electra_layers_total"],
                       lm_temperature=args["lm_temperature"])
    teacher = teacher.eval()
    model = MultiTaskHighwayCLSPretraining(student, teacher, eps, device if args["move_unused_to_cpu"] else None).to(device)
    trainable_model = student

    if dino_w == 0:
        model.teacher = None
        teacher = None
        clean_memory()
    del teacher
    if local_rank == 0 and rank == 0:
        print("[Train]: Time = %s, Trainable Params = %s" % (get_time_string(), numel(trainable_model) / 1_000_000))

    if args["pretrained_model"] is not None and os.path.exists(args["pretrained_model"]):
        state_dict = torch.load(args["pretrained_model"], map_location='cpu' if args['cpu'] else 'cuda:%d' % gpu_device)

        try:
            trainable_model.load_state_dict(state_dict, strict=True)
            load_type = "strict"
        except:
            try:
                try:
                    trainable_model.load_state_dict(state_dict, strict=False)
                    load_type = "not_strict"
                except:
                    state_dict = {k: v for k, v in state_dict.items() if k.startswith("backbone.")}
                    trainable_model.load_state_dict(state_dict, strict=False)
                    load_type = "not_strict_no_ffn"
            except:
                try:
                    try:
                        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
                        trainable_model.load_state_dict(state_dict, strict=True)
                        load_type = "strict-from-ddp"
                    except:
                        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
                        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("backbone.")}
                        trainable_model.load_state_dict(state_dict, strict=True)
                        load_type = "strict-from-ddp-no-ffn"
                except:
                    try:
                        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
                        trainable_model.load_state_dict(state_dict, strict=False)
                        load_type = "not_strict-from-ddp"
                    except:
                        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
                        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("backbone.")}
                        trainable_model.load_state_dict(state_dict, strict=False)
                        load_type = "not_strict-from-ddp-no-ffn"
        if dino_w > 0:
            student_teacher_param_update(model.student, model.teacher, 0.001, device if args["move_unused_to_cpu"] else None)

        print("[Train]: Time = %s, Loaded Pretrained model with Load type = %s, Torch Version = %s" % (get_time_string(), load_type, torch.__version__))
        del state_dict
    model = model.train()

    # print("[Train]: Time = %s, Trainable Params = %s" % (get_time_string(), {k for k, v in model.named_parameters() if v.requires_grad}))
    if args["world_size"] > 1:
        # model = FSDP(model, **fsdp_params)  # find_unused_parameters=True

        trainable_model = DDP(trainable_model, device_ids=None if args["cpu"] else [gpu_device], find_unused_parameters=True, bucket_cap_mb=50)  # find_unused_parameters=True
        model.student = trainable_model

    if dino_w > 0:
        model.teacher = model.teacher.eval()
        student_teacher_param_update(model.student, model.teacher, 0.01, device if args["move_unused_to_cpu"] else None)
    try:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
        trainable_model.register_comm_hook(state=None, hook=fp16_compress_hook)

    except:
        print("[Train]: Time = %s, No fp16_compress_hook present, Torch Version = %s" % (get_time_string(), torch.__version__))

    del backbone
    del teacher_backbone
    del student
    clean_memory()
    barrier()
    optc = optimizer_config.to_dict()
    trainable_params = list(filter(lambda p: p.requires_grad, trainable_model.parameters()))
    if args["optimizer"] == "adamw":
        optimizer = torch.optim.AdamW(trainable_params, lr=optc["lr"], eps=optc["eps"], weight_decay=optc["weight_decay"],
                                      betas=(optc["beta_1"], optc["beta_2"]))
    elif args["optimizer"] == "sgd":
        optimizer = torch.optim.SGD(trainable_params, lr=optc["lr"], momentum=0.9, weight_decay=optc["weight_decay"], nesterov=True)
    elif args["optimizer"] == "novograd":
        optimizer = Novograd(trainable_params, lr=optc["lr"], eps=optc["eps"], betas=(optc["beta_1"], optc["beta_2"]), weight_decay=optc["weight_decay"],)
    elif args["optimizer"] == "rangerlars":
        optimizer = RangerLars(trainable_params, lr=optc["lr"], eps=optc["eps"], betas=(optc["beta_1"], optc["beta_2"]), weight_decay=optc["weight_decay"],)
    else:
        raise ValueError
    # print("[Train]: Time = %s, Trainable Params = %s" % (get_time_string(), {k for k, v in trainable_model.named_parameters() if v.requires_grad}))
    del trainable_params
    optimizer.zero_grad(set_to_none=True)
    model_save_dir = args["model_save_dir"]
    model_save_name = args["model_save_name"]

    set_seeds(args["seed"] + rank)
    if local_rank == 0:
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
        assert os.path.exists(model_save_dir)

    try:
        dataloader = build_dataloader(os.path.join(args["dataset"], "all_512_only"), args["shuffle_dataset"], batch_size,
                                      tokenizer, args["cls_tokens"],
                                      world_size=args["world_size"], num_workers=args["num_workers"], max_length=512)
        dataloader128 = build_dataloader(os.path.join(args["dataset"], "all_128_only"), args["shuffle_dataset"], batch_size * 6,
                                      tokenizer, args["cls_tokens"],
                                      world_size=args["world_size"], num_workers=args["num_workers"], max_length=128)
        dataloader256 = build_dataloader(os.path.join(args["dataset"], "all_256_only"), args["shuffle_dataset"], batch_size * 3,
                                      tokenizer, args["cls_tokens"],
                                      world_size=args["world_size"], num_workers=args["num_workers"], max_length=256)
    except:
        print("[WARN] [Train]: Time = %s, All dataloaders and datasets are same = %s" % (get_time_string(), args["dataset"]))
        dataloader = build_dataloader(args["dataset"], args["shuffle_dataset"], batch_size,
                                      tokenizer, args["cls_tokens"],
                                      world_size=args["world_size"], num_workers=args["num_workers"], max_length=512)
        dataloader128 = build_dataloader(args["dataset"], args["shuffle_dataset"], batch_size * 4,
                                         tokenizer, args["cls_tokens"],
                                         world_size=args["world_size"], num_workers=args["num_workers"], max_length=128)
        dataloader256 = build_dataloader(args["dataset"], args["shuffle_dataset"], batch_size * 2,
                                         tokenizer, args["cls_tokens"],
                                         world_size=args["world_size"], num_workers=args["num_workers"], max_length=256)

    iter_size = max(args["accumulation_steps"], 1)
    no_sync = iter_size > 1
    steps_per_epoch = int(np.ceil(len(dataloader.sampler) / (batch_size * iter_size)) if dataloader.sampler is not None else (len(dataloader) / iter_size))
    if local_rank == 0:
        print("[Train]: Time = %s, Optimizer and Scheduler Initialised, max lr = %.5f, steps_per_epoch = %s, batch size = %s, dataloader length = %s, Sampler Present = %s, Sampler Length = %s" %
              (get_time_string(), optc["lr"], steps_per_epoch, batch_size, len(dataloader), dataloader.sampler is not None, len(dataloader.sampler) if dataloader.sampler is not None else -1))

    dataloader = get_next(dataloader)
    dataloader128 = get_next(dataloader128)
    dataloader256 = get_next(dataloader256)
    log_every_steps = args["log_every_steps"] * iter_size
    save_every_steps = args["save_every_steps"]

    # scheduler = optimization.get_constant_schedule_with_warmup(optimizer, optc["warmup_steps"])
    # scheduler = optimization.get_linear_schedule_with_warmup(optimizer, optc["warmup_steps"], args["epochs"] * len(dataloader))
    div_factor = optc["lr"]/1e-6
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, optc["lr"], total_steps=args["total_steps"],
                                                    div_factor=div_factor, three_phase=False, pct_start=0.06, anneal_strategy="linear", cycle_momentum=False)
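    # OneCycleLR starts at max_lr / div_factor; with div_factor = optc["lr"] / 1e-6 the schedule
    # therefore ramps linearly from 1e-6 up to optc["lr"] over the first 6% of total_steps
    # (pct_start=0.06) and then anneals linearly back down for the remaining steps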

    # scheduler1 = optimization.get_constant_schedule_with_warmup(optimizer, optc["warmup_steps"])
    # scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=(steps_per_epoch * args["epochs"]) // args["lr_steps"], gamma=0.5)
    # scheduler = [scheduler1, scheduler2]

    barrier()

    gradient_clipping = optc["gradient_clipping"]

    group = "%s-%s-%s-%sN-%s" % (args["wandb_name"], ds_name, args["model_config"], args["nodes"], time_string)
    wandb_init_args = dict(project="de_lm", name="%s-%s-%s-%s" % (group, args["nr"], rank, local_rank), group=group, id=f"{group}-worker-{nr}-{rank}-{local_rank}",
                           config={"args":args, "optimizer_config": optc},
                           settings=wandb.Settings(start_method="fork"))

    time.sleep(random.random())
    wandb.init(**wandb_init_args)

    full_times = []
    batch_times = []
    model_times = []
    model.zero_grad(set_to_none=True)
    samples_processed = 0
    samples_processed_this_log_iter = 0
    if args["detect_anomaly"]:
        torch.autograd.set_detect_anomaly(True)

    def hook(grad):
        is_nan_inf = torch.logical_not(torch.isfinite(grad))
        if is_nan_inf.any():
            # print("[GRAD-HOOK]: Time = %s, Param Name = %s, Detected Nan/Inf" % (get_time_string(), name_of_param))
            grad[is_nan_inf] = 0.0
            return grad
        return None

    if not args["no_autocast"] and args["backward_hook"]:
        for name, param in model.named_parameters():
            param.register_hook(hook)


    dino_center = None
    discriminator_dino_center = None
    total_steps = args["total_steps"]
    steps_done = 0
    step = 0

    start_time = time.time()
    while steps_done < total_steps:
        random.seed(step)
        len_proba = random.random()
        if len_proba < 0.5:
            batch = dataloader128()
        elif len_proba < 0.6:
            batch = dataloader256()
        else:
            batch = dataloader()

        epoch_128 = dataloader128.epoch
        epoch_256 = dataloader256.epoch
        epoch_512 = dataloader.epoch

        # batch = None
        # if len_proba < 0.9:
        #     batches = [dataloader128() for _ in range(4)]
        # elif len_proba < 0.97:
        #     batches = [dataloader256() for _ in range(2)]
        # else:
        #     batch = dataloader()
        #
        # if batch is None:
        #     keys = batches[0].keys()
        #     batch = dict()
        #     for k in keys:
        #         elems = [b[k] for b in batches]
        #         if isinstance(elems[0], (list, tuple)):
        #             new_elems = [i for e in elems for i in e]
        #         elif isinstance(elems[0], torch.Tensor):
        #             new_elems = torch.cat(elems, 0)
        #         else:
        #             raise TypeError("Expected List or Tensor")
        #         batch[k] = new_elems

        key = list(batch.keys())[0]
        bs_size = list(batch[key].size())
        batch = {k: v.to(device, non_blocking=True) if hasattr(v, "to") else v for k, v in batch.items()}

        gen_batch_time = time.time() - start_time

        teacher_update_w = np.interp(steps_done, [0, args["teacher_warmup_steps"]], [0.95, 0.999])
        inner_model = getattr(trainable_model, "module", trainable_model)
        if hasattr(inner_model, "start_from_proba"):
            start_from_proba = np.interp(steps_done, [0, args["warmup_steps"], args["warmup_steps"] * 2], [0.0, 0.0, args["start_from_proba"]])
            inner_model.start_from_proba = start_from_proba
        if hasattr(inner_model.backbone.encoder, "sampling_alpha") and args["sampling_alpha"] is not None and args["sampling_alpha"] != 1.0:
            sampling_alpha = np.interp(steps_done, [0, args["warmup_steps"], args["warmup_steps"] * 2], [1.0, 1.0, args["sampling_alpha"]])
            inner_model.backbone.encoder.sampling_alpha = max(sampling_alpha, 0.01)
            inner_model.sampling_alpha = max(sampling_alpha, 0.01)
        if args["dino_w"] > 0:
            dino_w = np.interp(steps_done, [0, args["teacher_warmup_steps"], args["teacher_warmup_steps"] * 2], [0.0, 0.0, args["dino_w"]])
            inner_model.dino_w = dino_w
        lm_temperature = np.interp(steps_done, [0, args["warmup_steps"], args["warmup_steps"] * 2],
                                   [args["lm_temperature"], args["lm_temperature"], args["lm_temperature"] + 1.0])
        inner_model.lm_temperature = lm_temperature

        batch_times.append(gen_batch_time)
        if (steps_done + 1) % save_every_steps == 0 or (args["total_steps"] is not None and (steps_done + 1) >= args["total_steps"]):
            state_dict = trainable_model.state_dict() if not isinstance(trainable_model, DDP) else trainable_model.module.state_dict()
            if local_rank == 0:
                torch.save(state_dict, os.path.join(model_save_dir, model_save_name))
            del state_dict
            clean_memory()
            barrier()
            if args["total_steps"] is not None and (steps_done + 1) >= args["total_steps"]:
                return

        samples_processed += int(batch[key].size(0))
        samples_processed_this_log_iter += int(batch[key].size(0))
        inner_args = dict(no_autocast=args["no_autocast"], cpu=args["cpu"])
        validation_iter = (step + 1) % log_every_steps == 0 or step == 0
        model_start = time.time()
        if no_sync and (step + 1) % iter_size != 0 and hasattr(trainable_model, "no_sync"):
            with trainable_model.no_sync():
                output = train_inner_loop(inner_args, model, batch, optimizer,
                                          scheduler, gradient_clipping, iter_size=iter_size,
                                          no_sync=True, validation_iter=validation_iter,
                                          dino_center=dino_center,
                                          discriminator_dino_center=discriminator_dino_center,
                                          freeze_last_layer=steps_done < args["freeze_last_layer"], step=step + 1)
            model_times.append(time.time() - model_start)
        else:
            output = train_inner_loop(inner_args, model, batch, optimizer,
                                      scheduler, gradient_clipping, iter_size=iter_size,
                                      no_sync=False, validation_iter=validation_iter,
                                      dino_center=dino_center,
                                      discriminator_dino_center=discriminator_dino_center,
                                      freeze_last_layer=steps_done < args["freeze_last_layer"], step=step + 1)
            optimizer.zero_grad(set_to_none=True)
            steps_done += 1
            model_times.append(time.time() - model_start)

        step += 1
        del batch
        if dino_w > 0 and (step + 1) % iter_size == 0:
            student_teacher_param_update(model.student, model.teacher, teacher_update_w, device if args["move_unused_to_cpu"] else None)
        dino_center = output.pop("dino_center", None)
        discriminator_dino_center = output.pop("discriminator_dino_center", None)
        # if dino_w > 0 and (step + 1) % (1 * iter_size) == 0 and args["world_size"] > 1:
        #     if dino_center is not None:
        #         dtype = dino_center.dtype
        #         dino_center = dino_center.type(torch.float64) / args["world_size"]
        #         torch.distributed.all_reduce(dino_center, torch.distributed.ReduceOp.SUM)
        #         dino_center = dino_center.type(dtype)
        #     if discriminator_dino_center is not None:
        #         dtype = discriminator_dino_center.dtype
        #         discriminator_dino_center = discriminator_dino_center.type(torch.float64) / args["world_size"]
        #         torch.distributed.all_reduce(discriminator_dino_center, torch.distributed.ReduceOp.SUM)
        #         discriminator_dino_center = discriminator_dino_center.type(dtype)
        if (step + 1) % (4 * iter_size) == 0 and hasattr(getattr(trainable_model, "module", trainable_model).backbone, "layer_normalizers") and args["world_size"] > 1:
            layer_normalizers = getattr(trainable_model, "module", trainable_model).backbone.layer_normalizers
            if layer_normalizers is not None:
                dtype = layer_normalizers.dtype
                layer_normalizers = layer_normalizers.type(torch.float64)
                torch.distributed.all_reduce(layer_normalizers, torch.distributed.ReduceOp.SUM)
                layer_normalizers = layer_normalizers / args["world_size"]
                getattr(trainable_model, "module", trainable_model).backbone.layer_normalizers = layer_normalizers.type(dtype)

        if (step + 1) % (4 * iter_size) == 0 and hasattr(getattr(trainable_model, "module", trainable_model).backbone, "layer_normalizers_small") and args["world_size"] > 1:
            layer_normalizers_small = getattr(trainable_model, "module", trainable_model).backbone.layer_normalizers_small
            if layer_normalizers_small is not None:
                dtype = layer_normalizers_small.dtype
                layer_normalizers_small = layer_normalizers_small.type(torch.float64)
                torch.distributed.all_reduce(layer_normalizers_small, torch.distributed.ReduceOp.SUM)
                layer_normalizers_small = layer_normalizers_small / args["world_size"]
                getattr(trainable_model, "module", trainable_model).backbone.layer_normalizers_small = layer_normalizers_small.type(dtype)

        full_time = time.time() - start_time
        full_times.append(full_time)
        if step == 1 and local_rank == 0:
            print("[Train]: Time = %s, First Batch Training for Rank = %s" % (get_time_string(), rank))
        if validation_iter:
            steps_remaining = total_steps - steps_done
            # print({k for k, v in output.items() if isinstance(v, torch.Tensor)})
            output = {k: float(v) if v else v for k, v in output.items()}
            samples_per_second = samples_processed_this_log_iter / np.sum(full_times)
            wandb_log = dict(lr=optimizer.param_groups[0]['lr'], step=step, samples_processed=samples_processed, samples_per_second=samples_per_second,
                             batch_times=np.mean(batch_times), full_times=np.mean(full_times), model_times=np.mean(model_times),
                             steps_remaining=steps_remaining, pct_complete=(100 * steps_done / total_steps),
                             epoch_128=epoch_128, epoch_256=epoch_256, epoch_512=epoch_512,
                             **{k: v for k, v in output.items() if v is not None})

            wandb.log(wandb_log)
            if local_rank == 0:
                print("[Train]: Time = %s, Rank = %s, steps = %s, samples_processed=%s, batch_size = %s, Details = %s, LR = %s" %
                      (get_time_string(), rank, step, samples_processed, bs_size, output, optimizer.param_groups[0]['lr']))
                print("[Train-Timings]: Time = %s, Batch time = %.4f, Full Time = %.4f, Model Time = %.4f, samples_per_second = %s, steps_remaining = %s, pct_complete = %.4f" % (
                    get_time_string(), np.mean(batch_times), np.mean(full_times), np.mean(model_times), samples_per_second, steps_remaining, (100 * steps_done / total_steps),))
                # print("Step = %s, Steps Done = %s, log_every_steps = %s, total_steps = %s, steps_remaining = %s, validation_iter = %s, %s" % (step, steps_done, log_every_steps, total_steps, steps_remaining, validation_iter, (step + 1) % log_every_steps == 0))
            batch_times = []
            full_times = []
            model_times = []
            samples_processed_this_log_iter = 0
            if args["enable_layer_normalizers_statistics"] and local_rank == 0:
                backbone = getattr(model.student, "module", model.student).backbone
                stats = backbone.layer_normalizers
                inp_stats = backbone.encoder.layer_normalizers_statistics
                norms = stats[:, 2, 0].tolist()
                inp_norms = inp_stats[:, 2, 0].tolist()
                centers = stats[:, 0, 0:8].tolist()
                inp_centers = inp_stats[:, 0, 0:8].tolist()
                stds = stats[:, 1, 0:8].tolist()
                inp_stds = inp_stats[:, 1, 0:8].tolist()

                dist_stats = backbone.encoder.distance_statistics.tolist()

                print("Branch Norms = \n", tabulate(pd.DataFrame(norms), tablefmt="psql"))
                print("Skip Norms = \n", tabulate(pd.DataFrame({"norm": inp_norms, "dist": dist_stats}), tablefmt="psql"))

                print("Branch centers = \n", tabulate(pd.DataFrame(centers), tablefmt="psql"))
                print("Skip centers = \n", tabulate(pd.DataFrame(inp_centers), tablefmt="psql"))

                print("Branch stds = \n", tabulate(pd.DataFrame(stds), tablefmt="psql"))
                print("Skip stds = \n", tabulate(pd.DataFrame(inp_stds), tablefmt="psql"))


            # clean_memory()
            # barrier()
        del output
        del bs_size
        start_time = time.time()
    print("Time = %s, Finished Training for Rank = %s" % (get_time_string(), rank))
    state_dict = trainable_model.state_dict() if not isinstance(trainable_model, DDP) else trainable_model.module.state_dict()
    if local_rank == 0:
        torch.save(state_dict, os.path.join(model_save_dir, model_save_name))
    del model
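Exemple #19 keeps the frozen teacher in sync with the student via student_teacher_param_update, with teacher_update_w interpolated from 0.95 towards 0.999 during warmup. That helper is not shown here; a minimal sketch of the DINO-style exponential-moving-average update it presumably performs:

import torch

@torch.no_grad()
def ema_update(student: torch.nn.Module, teacher: torch.nn.Module, momentum: float):
    # teacher <- momentum * teacher + (1 - momentum) * student, parameter by parameter
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(momentum).add_(p_s.detach(), alpha=1.0 - momentum)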
Exemple #20
0
def main(config):
    train_loader, val_loader = get_loader(config)
    n_data = len(train_loader.dataset)
    logger.info(f"length of training dataset: {n_data}")
    n_data = len(val_loader.dataset)
    logger.info(f"length of validation dataset: {n_data}")

    model, criterion = build_scene_segmentation(config)
    model.cuda()
    criterion.cuda()

    if config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config.batch_size *
                                    dist.get_world_size() / 8 *
                                    config.base_learning_rate,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.batch_size *
                                     dist.get_world_size() / 8 *
                                     config.base_learning_rate,
                                     weight_decay=config.weight_decay)
    elif config.optimizer == 'adamW':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config.batch_size *
                                      dist.get_world_size() / 8 *
                                      config.base_learning_rate,
                                      weight_decay=config.weight_decay)
    else:
        raise NotImplementedError(
            f"Optimizer {config.optimizer} not supported")

    scheduler = get_scheduler(optimizer, len(train_loader), config)

    model = DistributedDataParallel(model,
                                    device_ids=[config.local_rank],
                                    broadcast_buffers=False)

    runing_vote_logits = [
        np.zeros((config.num_classes, l.shape[0]), dtype=np.float32)
        for l in val_loader.dataset.sub_clouds_points_labels
    ]

    # optionally resume from a checkpoint
    if config.load_path:
        assert os.path.isfile(config.load_path)
        load_checkpoint(config, model, optimizer, scheduler)
        logger.info("==> checking loaded ckpt")
        validate('resume',
                 val_loader,
                 model,
                 criterion,
                 runing_vote_logits,
                 config,
                 num_votes=2)

    # tensorboard
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=config.log_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(config.start_epoch, config.epochs + 1):
        train_loader.sampler.set_epoch(epoch)
        val_loader.sampler.set_epoch(epoch)
        train_loader.dataset.epoch = epoch - 1
        tic = time.time()
        loss = train(epoch, train_loader, model, criterion, optimizer,
                     scheduler, config)

        logger.info('epoch {}, total time {:.2f}, lr {:.5f}'.format(
            epoch, (time.time() - tic), optimizer.param_groups[0]['lr']))
        if epoch % config.val_freq == 0:
            validate(epoch,
                     val_loader,
                     model,
                     criterion,
                     runing_vote_logits,
                     config,
                     num_votes=2)

        if dist.get_rank() == 0:
            # save model
            save_checkpoint(config, epoch, model, optimizer, scheduler)

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)

    validate('Last',
             val_loader,
             model,
             criterion,
             runing_vote_logits,
             config,
             num_votes=20)
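Exemple #20 applies the linear scaling rule to the learning rate: the base rate is multiplied by the global batch size (per-GPU batch size times world size) divided by a reference batch of 8. A small numeric illustration, with an example base rate:

base_learning_rate = 0.01
reference_batch = 8
for world_size, per_gpu_batch in [(1, 8), (4, 8), (8, 16)]:
    lr = per_gpu_batch * world_size / reference_batch * base_learning_rate
    print(f"world_size={world_size} batch={per_gpu_batch} -> lr={lr:.3f}")
# -> lr=0.010, lr=0.040, lr=0.160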
Exemple #21
0
def run(max_epochs=None,
        device=None,
        batch_size=24,
        max_sequence_length=128,
        random_sequence_length=False,
        epoch_size=None,
        seed=None,
        data_dir='data',
        real_dataset='webtext',
        fake_dataset='xl-1542M-nucleus',
        token_dropout=None,
        large=False,
        learning_rate=2e-5,
        weight_decay=0,
        **kwargs):
    args = locals()
    rank, world_size = setup_distributed()

    if device is None:
        device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'

    print('rank:', rank, 'world_size:', world_size, 'device:', device)

    import torch.distributed as dist
    if distributed() and rank > 0:
        dist.barrier()

    model_name = 'roberta-large' if large else 'roberta-base'
    tokenization_utils.logger.setLevel('ERROR')
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    model = RobertaForSequenceClassification.from_pretrained(model_name).to(
        device)

    if rank == 0:
        summary(model)
        if distributed():
            dist.barrier()

    if world_size > 1:
        model = DistributedDataParallel(model, [rank],
                                        output_device=rank,
                                        find_unused_parameters=True)

    train_loader, validation_loader = load_datasets(
        data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
        max_sequence_length, random_sequence_length, epoch_size, token_dropout,
        seed)

    optimizer = Adam(model.parameters(),
                     lr=learning_rate,
                     weight_decay=weight_decay)
    epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)

    logdir = os.environ.get("OPENAI_LOGDIR", "logs")
    os.makedirs(logdir, exist_ok=True)

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(logdir) if rank == 0 else None
    best_validation_accuracy = 0

    for epoch in epoch_loop:
        if world_size > 1:
            train_loader.sampler.set_epoch(epoch)
            validation_loader.sampler.set_epoch(epoch)

        train_metrics = train(model, optimizer, device, train_loader,
                              f'Epoch {epoch}')
        validation_metrics = validate(model, device, validation_loader)

        combined_metrics = _all_reduce_dict(
            {
                **validation_metrics,
                **train_metrics
            }, device)

        combined_metrics["train/accuracy"] /= combined_metrics[
            "train/epoch_size"]
        combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
        combined_metrics["validation/accuracy"] /= combined_metrics[
            "validation/epoch_size"]
        combined_metrics["validation/loss"] /= combined_metrics[
            "validation/epoch_size"]

        if rank == 0:
            for key, value in combined_metrics.items():
                writer.add_scalar(key, value, global_step=epoch)

            if combined_metrics[
                    "validation/accuracy"] > best_validation_accuracy:
                best_validation_accuracy = combined_metrics[
                    "validation/accuracy"]

                model_to_save = model.module if hasattr(model,
                                                        'module') else model
                torch.save(
                    dict(epoch=epoch,
                         model_state_dict=model_to_save.state_dict(),
                         optimizer_state_dict=optimizer.state_dict(),
                         args=args), os.path.join(logdir, "best-model.pt"))
Exemple #22
0
    def __init__(self, model, regime=None,
                 criterion=None,
                 label_smoothing=0,
                 target_forcing='teacher',
                 print_freq=10,
                 eval_freq=1000,
                 save_freq=1000,
                 grad_clip=None,
                 embedding_grad_clip=None,
                 max_tokens=None,
                 chunk_batch=1,
                 duplicates=1,
                 save_info={},
                 save_path='.',
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 avg_loss_time=True,
                 distributed=False,
                 local_rank=0,
                 dtype=torch.float,
                 loss_scale=1,
                 device_ids=None,
                 device="cuda"):
        super(Seq2SeqTrainer, self).__init__()
        self.model = model
        self.criterion = criterion or CrossEntropyLoss(
            ignore_index=PAD, smooth_eps=label_smoothing, reduction='sum', from_logits=False)

        self.optimizer = OptimRegime(
            self.model, regime=regime, use_float_copy=dtype == torch.float16)

        if target_forcing == 'teacher':
            self.target_forcing = TeacherForcing(batch_first=self.batch_first)
        else:
            self.target_forcing = DecodedInputTargets()
        self.grad_clip = grad_clip
        self.embedding_grad_clip = embedding_grad_clip
        self.epoch = 0
        self.training_steps = 0
        self.save_info = save_info
        self.device = device
        self.dtype = dtype
        self.loss_scale = loss_scale
        self.max_tokens = max_tokens
        self.chunk_batch = chunk_batch
        self.duplicates = duplicates
        self.print_freq = print_freq
        self.eval_freq = eval_freq
        self.perplexity = float('inf')
        self.device_ids = device_ids
        self.avg_loss_time = avg_loss_time
        self.model_with_loss = AddLossModule(self.model, self.criterion)
        self.distributed = distributed
        self.local_rank = local_rank
        if distributed:
            self.model_with_loss = DistributedDataParallel(
                self.model_with_loss,
                device_ids=[local_rank],
                output_device=local_rank)
        else:
            if isinstance(self.device_ids, tuple):
                self.model_with_loss = DataParallel(self.model_with_loss,
                                                    self.device_ids,
                                                    dim=0 if self.batch_first else 1)
        self.save_path = save_path
        self.save_freq = save_freq
        self.checkpoint_filename = checkpoint_filename
        self.keep_checkpoints = keep_checkpoints + 1
        results_file = os.path.join(save_path, 'results')
        self.results = ResultsLog(results_file,
                                  params=save_info.get('config', None))
        self.watcher = None
        self.streams = {}
Exemple #23
0
    def __init__(self, opt):
        super(VideoAttentionModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            train_opt = opt['train']
            self.patch_size = opt['datasets']['train']['patch_size']
            self.patch_repeat = opt['datasets']['train']['patch_repeat']
            self.use_diff = opt['datasets']['train']['use_diff']

            self.netG.train()

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'bce':
                self.cri_pix = nn.BCEWithLogitsLoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()
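Note: `CharbonnierLoss` is referenced above for the 'cb' option but not defined in this snippet. A minimal sketch of the commonly used Charbonnier formulation is given below; the `eps` value and the sum reduction are assumptions chosen to mirror the `reduction='sum'` losses of the other branches.

import torch
import torch.nn as nn


class CharbonnierLoss(nn.Module):
    """Charbonnier penalty, a differentiable L1-like loss (sketch)."""

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps  # small constant keeping the sqrt differentiable at zero

    def forward(self, pred, target):
        diff = pred - target
        # sum-reduced, matching the reduction='sum' used by the L1/L2/BCE branches above
        return torch.sum(torch.sqrt(diff * diff + self.eps))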
Exemple #24
0
def main(rank):
    env = gym.make("CartPole-v0")
    observe_dim = 4
    action_num = 2
    max_episodes = 2000
    max_steps = 200
    solved_reward = 190
    solved_repeat = 5

    # initialize the distributed world first
    world = World(world_size=4, rank=rank, name=str(rank), rpc_timeout=20)

    servers = model_server_helper(model_num=1)
    apex_group = world.create_rpc_group("apex", ["0", "1", "2", "3"])

    if rank in (2, 3):
        # learner_group.group is the wrapped torch.distributed.ProcessGroup
        learner_group = world.create_collective_group(ranks=[2, 3])

        # wrap the model with DistributedDataParallel
        # if current process is learner process 2 or 3
        q_net = DistributedDataParallel(module=QNet(observe_dim, action_num),
                                        process_group=learner_group.group)
        q_net_t = DistributedDataParallel(module=QNet(observe_dim, action_num),
                                          process_group=learner_group.group)
    else:
        q_net = QNet(observe_dim, action_num)
        q_net_t = QNet(observe_dim, action_num)

    # we may use a smaller batch size to train if we are using
    # DistributedDataParallel
    dqn_apex = DQNApex(q_net,
                       q_net_t,
                       t.optim.Adam,
                       nn.MSELoss(reduction='sum'),
                       apex_group,
                       servers,
                       batch_size=50)

    # synchronize all processes in the group, make sure
    # distributed buffer has been created on all processes in apex_group
    apex_group.barrier()

    # manually control syncing to improve performance
    dqn_apex.set_sync(False)
    if rank in (0, 1):
        # Processes 0 and 1 are workers (samplers)
        # begin training
        episode, step, reward_fulfilled = 0, 0, 0
        smoothed_total_reward = 0

        while episode < max_episodes:
            # sleep to wait for the learners to keep up
            sleep(0.1)
            episode += 1
            total_reward = 0
            terminal = False
            step = 0

            state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)

            # manually pull the newest parameters
            dqn_apex.manual_sync()
            while not terminal and step <= max_steps:
                step += 1
                with t.no_grad():
                    old_state = state
                    # agent model inference
                    action = dqn_apex.act_discrete_with_noise(
                        {"state": old_state})
                    state, reward, terminal, _ = env.step(action.item())
                    state = t.tensor(state, dtype=t.float32)\
                        .view(1, observe_dim)
                    total_reward += reward

                    dqn_apex.store_transition({
                        "state": {"state": old_state},
                        "action": {"action": action},
                        "next_state": {"state": state},
                        "reward": reward,
                        "terminal": terminal or step == max_steps,
                    })

            smoothed_total_reward = (smoothed_total_reward * 0.9 +
                                     total_reward * 0.1)
            logger.info("Process {} Episode {} total reward={:.2f}".format(
                rank, episode, smoothed_total_reward))

            if smoothed_total_reward > solved_reward:
                reward_fulfilled += 1
                if reward_fulfilled >= solved_repeat:
                    logger.info("Environment solved!")

                    # exiting here will cause torch RPC to complain,
                    # since other processes may not have finished yet;
                    # this is just for demonstration.
                    exit(0)
            else:
                reward_fulfilled = 0

    elif rank in (2, 3):
        # wait for enough samples
        while dqn_apex.replay_buffer.all_size() < 500:
            sleep(0.1)
        while True:
            dqn_apex.update()
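Note: `QNet(observe_dim, action_num)` is constructed above but defined elsewhere. The sketch below is a hypothetical Q-network with that constructor signature; the hidden width and depth are illustrative only.

import torch.nn as nn


class QNet(nn.Module):
    # hypothetical Q-network matching the QNet(observe_dim, action_num) calls above
    def __init__(self, state_dim, action_num, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_num),  # one Q-value per discrete action
        )

    def forward(self, state):
        return self.net(state)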
Exemple #25
0
    def __init__(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        data_loader: torch.utils.data.DataLoader,
        patience: Optional[int] = None,
        validation_metric: str = "-loss",
        validation_data_loader: torch.utils.data.DataLoader = None,
        num_epochs: int = 20,
        serialization_dir: Optional[str] = None,
        checkpointer: Checkpointer = None,
        model_save_interval: float = None,
        cuda_device: int = -1,
        grad_norm: Optional[float] = None,
        grad_clipping: Optional[float] = None,
        learning_rate_scheduler: Optional[LearningRateScheduler] = None,
        momentum_scheduler: Optional[MomentumScheduler] = None,
        tensorboard_writer: TensorboardWriter = None,
        log_batch_size_period: Optional[int] = None,
        moving_average: Optional[MovingAverage] = None,
        distributed: bool = False,
        local_rank: int = 0,
        world_size: int = 1,
        num_gradient_accumulation_steps: int = 1,
        opt_level: Optional[str] = None,
    ) -> None:
        """
        A trainer for doing supervised learning. It just takes a labeled dataset
        and a `DataLoader`, and uses the supplied `Optimizer` to learn the weights
        for your model over some fixed number of epochs. You can also pass in a validation
        dataloader and enable early stopping. There are many other bells and whistles as well.

        # Parameters

        model : `Model`, required.
            An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
            their `forward` method returns a dictionary with a "loss" key, containing a
            scalar tensor representing the loss function to be optimized.

            If you are training your model using GPUs, your model should already be
            on the correct device. (If you use `Trainer.from_params` this will be
            handled for you.)
        optimizer : `torch.optim.Optimizer`, required.
            An instance of a Pytorch Optimizer, instantiated with the parameters of the
            model to be optimized.
        data_loader : `DataLoader`, required.
            A pytorch `DataLoader` containing your `Dataset`, yielding padded indexed batches.
        patience : Optional[int] > 0, optional (default=None)
            Number of epochs to be patient before early stopping: the training is stopped
            after `patience` epochs with no improvement. If given, it must be `> 0`.
            If None, early stopping is disabled.
        validation_metric : str, optional (default="loss")
            Validation metric to measure for whether to stop training using patience
            and whether to serialize an `is_best` model each epoch. The metric name
            must be prepended with either "+" or "-", which specifies whether the metric
            is an increasing or decreasing function.
        validation_data_loader : `DataLoader`, optional (default=None)
            A `DataLoader` to use for the validation set.  If `None`, then
            use the training `DataLoader` with the validation data.
        num_epochs : int, optional (default = 20)
            Number of training epochs.
        serialization_dir : str, optional (default=None)
            Path to directory for saving and loading model files. Models will not be saved if
            this parameter is not passed.
        checkpointer : `Checkpointer`, optional (default=None)
            A `Checkpointer` is responsible for periodically saving model weights.  If none is given
            here, we will construct one with default parameters.
        model_save_interval : `float`, optional (default=None)
            If provided, then serialize models every `model_save_interval`
            seconds within single epochs.  In all cases, models are also saved
            at the end of every epoch if `serialization_dir` is provided.
        cuda_device : `int`, optional (default = -1)
            An integer specifying the CUDA device(s) to use for this process. If -1, the CPU is used.
            Data parallelism is controlled at the allennlp train level, so each trainer will have a single
            GPU.
        grad_norm : `float`, optional, (default = None).
            If provided, gradient norms will be rescaled to have a maximum of this value.
        grad_clipping : `float`, optional (default = `None`).
            If provided, gradients will be clipped `during the backward pass` to have an (absolute)
            maximum of this value.  If you are getting `NaNs` in your gradients during training
            that are not solved by using `grad_norm`, you may need this.
        learning_rate_scheduler : `LearningRateScheduler`, optional (default = None)
            If specified, the learning rate will be decayed with respect to
            this schedule at the end of each epoch (or batch, if the scheduler implements
            the `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`,
            this will use the `validation_metric` provided to determine if learning has plateaued.
            To support updating the learning rate on every batch, this can optionally implement
            `step_batch(batch_num_total)` which updates the learning rate given the batch number.
        momentum_scheduler : `MomentumScheduler`, optional (default = None)
            If specified, the momentum will be updated at the end of each batch or epoch
            according to the schedule.
        tensorboard_writer : `TensorboardWriter`, optional
            If this is not provided, we will construct a `TensorboardWriter` with default
            parameters and use that.
        log_batch_size_period : `int`, optional, (default = `None`)
            If defined, how often to log the average batch size.
        moving_average : `MovingAverage`, optional, (default = None)
            If provided, we will maintain moving averages for all parameters. During training, we
            employ a shadow variable for each parameter, which maintains the moving average. During
            evaluation, we backup the original parameters and assign the moving averages to corresponding
            parameters. Be careful that when saving the checkpoint, we will save the moving averages of
            parameters. This is necessary because we want the saved model to perform as well as the validated
            model if we load it later. But this may cause problems if you restart the training from checkpoint.
        distributed : `bool`, optional, (default = False)
            If set, PyTorch's `DistributedDataParallel` is used to train the model on multiple GPUs. This also
            requires `world_size` to be greater than 1.
        local_rank : `int`, optional, (default = 0)
            This is the unique identifier of the `Trainer` in a distributed process group. The GPU device id is
            used as the rank.
        world_size : `int`, (default = 1)
            The number of `Trainer` workers participating in the distributed training.
        num_gradient_accumulation_steps : `int`, optional, (default = 1)
            Gradients are accumulated for the given number of steps before doing an optimizer step. This can
            be useful to accommodate batches that are larger than the RAM size. Refer to Thomas Wolf's
            [post](https://tinyurl.com/y5mv44fw) for details on Gradient Accumulation.
        opt_level : `str`, optional, (default = `None`)
            Each opt_level establishes a set of properties that govern Amp’s implementation of pure or mixed
            precision training. Must be a choice of `"O0"`, `"O1"`, `"O2"`, or `"O3"`.
            See the Apex [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties) for
            more details. If `None`, Amp is not used. Defaults to `None`.
        """
        super().__init__(serialization_dir, cuda_device, distributed,
                         local_rank, world_size)

        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        self.model = model

        self.data_loader = data_loader
        self._validation_data_loader = validation_data_loader
        self.optimizer = optimizer

        if patience is None:  # no early stopping
            if validation_data_loader:
                logger.warning(
                    "You provided a validation dataset but patience was set to None, "
                    "meaning that early stopping is disabled")
        elif (not isinstance(patience, int)) or patience <= 0:
            raise ConfigurationError(
                '{} is an invalid value for "patience": it must be a positive integer '
                "or None (if you want to disable early stopping)".format(
                    patience))

        # For tracking is_best_so_far and should_stop_early
        self._metric_tracker = MetricTracker(patience, validation_metric)
        # Get rid of + or -
        self._validation_metric = validation_metric[1:]

        self._num_epochs = num_epochs

        if checkpointer is not None:
            self._checkpointer = checkpointer
        else:
            self._checkpointer = Checkpointer(serialization_dir)

        self._model_save_interval = model_save_interval

        self._grad_norm = grad_norm
        self._grad_clipping = grad_clipping

        self._learning_rate_scheduler = learning_rate_scheduler
        self._momentum_scheduler = momentum_scheduler
        self._moving_average = moving_average

        # We keep the total batch number as an instance variable because it
        # is used inside a closure for the hook which logs activations in
        # `_enable_activation_logging`.
        self._batch_num_total = 0

        self._tensorboard = tensorboard_writer or TensorboardWriter(
            serialization_dir)
        self._tensorboard.get_batch_num_total = lambda: self._batch_num_total
        self._tensorboard.enable_activation_logging(self.model)

        self._log_batch_size_period = log_batch_size_period

        self._last_log = 0.0  # time of last logging

        self._num_gradient_accumulation_steps = num_gradient_accumulation_steps

        # Enable automatic mixed precision training with NVIDIA Apex.
        self._opt_level = opt_level
        if self._opt_level is not None:
            if amp is None:
                raise ConfigurationError((
                    "Apex not installed but opt_level was provided. Please install NVIDIA's Apex to enable"
                    " automatic mixed precision (AMP) training. See: https://github.com/NVIDIA/apex."
                ))

            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self._opt_level)

        # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
        # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
        # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
        #
        # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the
        # normal case, reference to `Model` is retained. This reference is only used in
        # these places: `model.__call__`, `model.train` and `model.eval`.
        if self._distributed:
            self._pytorch_model = DistributedDataParallel(
                self.model,
                device_ids=[self.cuda_device],
                find_unused_parameters=True)
        else:
            self._pytorch_model = self.model
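Note: the docstring above describes the '+'/'-' prefix of `validation_metric` and patience-based early stopping. The sketch below is not AllenNLP's `MetricTracker`; it is a minimal illustration, under those assumptions, of how the sign prefix and the patience counter can interact.

from math import inf


class SimpleMetricTracker:
    """Illustrative tracker: '+' means larger is better, '-' means smaller is better."""

    def __init__(self, patience, validation_metric="-loss"):
        self.patience = patience
        self.larger_is_better = validation_metric.startswith("+")
        self.best = -inf if self.larger_is_better else inf
        self.epochs_without_improvement = 0

    def add_metric(self, value):
        improved = value > self.best if self.larger_is_better else value < self.best
        if improved:
            self.best = value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return improved  # analogous to "is_best_so_far"

    def should_stop_early(self):
        # disabled when patience is None, as in the constructor check above
        return self.patience is not None and self.epochs_without_improvement >= self.patience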
Exemple #26
0
def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
    model.train()
    from functools import reduce
    import operator

    num_params = reduce(operator.add, (reduce(operator.mul, x.size())
                                       for x in model.parameters()))
    if model.group:
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.info(
            f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}")
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.info(f"total #prams = {total.item()}")
    else:
        logging.info(f"training model, #prams = {num_params}")
    vocab_size = 10000  # FIXME
    total_loss = 0.0
    start_time = time.time()
    word_counter = 0

    optimizer = optimizer(model)

    def get_first_device(model):
        if isinstance(model, DDP):
            model = model.module

        if not torch.cuda.is_available():
            return torch.device("cpu")
        if model.devices:
            return model.devices[0]
        else:
            return torch.cuda.current_device()

    def get_last_device(model):
        if isinstance(model, DDP):
            model = model.module

        if not torch.cuda.is_available():
            return torch.device("cpu")
        if model.devices:
            return model.devices[-1]
        else:
            return torch.cuda.current_device()

    pipe_group = model.group

    if args.ddp_zero:
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            process_group=get_data_parallel_group(),
            find_unused_parameters=False,
        )

    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (
            pipe_group.size() - 1):
        thing = {"input": torch.zeros(args.batch_size)}

        class FakeDataset:
            def __getitem__(self, index):
                return thing

            def __len__(self):
                return len(lm_dataloader)

        lm_dataloader = FakeDataset()

    for i, batch in enumerate(lm_dataloader):
        bi = batch["input"]
        if args.max_batch and i > args.max_batch:
            break
        optimizer.zero_grad()
        try:
            if (pipe_group is None
                    or pipe_group.rank() == 0) and not args.ddp_zero:
                tmp = batch["input"].to(get_first_device(model))
                output = model(tmp)
            else:
                output = model(batch["input"])
        except Exception as e:
            raise RuntimeError(
                f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = batch["target"].to(get_last_device(model))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            if args.ddp_zero:
                ddp_group = get_data_parallel_group()
                torch.distributed.all_reduce(loss,
                                             op=torch.distributed.ReduceOp.SUM,
                                             group=ddp_group)
                loss /= ddp_group.size()
            loss.backward()
            del target
        else:
            if args.ddp_zero:
                model.module.back_helper(output)
            else:
                model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            word_counter += batch["ntokens"]
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print(
                    "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".
                    format(i, word_counter / elapsed, cur_loss,
                           math.exp(cur_loss)))
                word_counter = 0
                total_loss = 0
                start_time = time.time()
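Note: the `num_params` computation at the top of this function multiplies out each parameter's shape with `reduce`. The equivalent, more idiomatic count with `Tensor.numel()` is sketched below as a side note.

import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    # equivalent to reduce(operator.add, (reduce(operator.mul, p.size()) for p in model.parameters()))
    return sum(p.numel() for p in model.parameters())


# e.g. logging.info(f"training model, #params = {count_parameters(model)}")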
Exemple #27
0
def train(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    elif not os.path.exists(args.dir):
        # create 40 random image/mask pairs for training
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])

    # partition dataset based on current rank number, every rank trains with its own data
    # it can avoid duplicated caching content in each rank, but will not do global shuffle before every epoch
    data_part = partition_dataset(
        data=train_files,
        num_partitions=dist.get_world_size(),
        shuffle=True,
        even_divisible=True,
    )[dist.get_rank()]

    train_ds = SmartCacheDataset(
        data=data_part,
        transform=train_transforms,
        replace_rate=0.2,
        cache_num=15,  # assuming 2 ranks in this example, every rank has 20 training images
        num_init_workers=2,
        num_replace_workers=2,
    )
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # start a typical PyTorch training
    epoch_loss_values = list()
    # start the replacement thread of SmartCache
    train_ds.start()

    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        # replace 20% of cache content for next epoch
        train_ds.update_cache()
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # stop replacement thread of SmartCache
    train_ds.shutdown()
    print(f"train completed, epoch losses: {epoch_loss_values}")
    if dist.get_rank() == 0:
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
    dist.destroy_process_group()
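Note: the comment before `partition_dataset` explains that every rank trains on its own non-overlapping slice of the file list. The helper below is not MONAI's `partition_dataset`; it is a minimal standard-library sketch of the same idea (seeded shuffle, optional padding for even divisibility, rank-strided slice).

import random


def partition_for_rank(files, rank, world_size, seed=0, even_divisible=True):
    """Return the slice of `files` that a given rank should train on (sketch)."""
    files = list(files)
    random.Random(seed).shuffle(files)  # same seed on every rank -> same global order
    if even_divisible and len(files) % world_size:
        # pad by repeating items so every rank receives the same number of files
        files += files[: world_size - len(files) % world_size]
    return files[rank::world_size]  # strided split: rank 0 gets items 0, ws, 2*ws, ...


# e.g. data_part = partition_for_rank(train_files, dist.get_rank(), dist.get_world_size())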
Exemple #28
0
def train(cfg):
    # Set seeds for determinism
    torch.manual_seed(cfg.training.seed)
    torch.cuda.manual_seed_all(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    random.seed(cfg.training.seed)

    main_proc = True
    device = torch.device("cpu" if cfg.training.no_cuda else "cuda")

    is_distributed = os.environ.get(
        "LOCAL_RANK")  # If local rank exists, distributed env

    if is_distributed:
        # when using NCCL, on failures, surviving nodes will deadlock on NCCL ops
        # because NCCL uses a spin-lock on the device. Set this env var
        # to enable a watchdog thread that will destroy stale NCCL communicators
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        print(f"Setting CUDA Device to {device_id}")

        dist.init_process_group(backend=cfg.training.dist_backend.value)
        main_proc = device_id == 0  # Main process handles saving of models and reporting

    if OmegaConf.get_type(cfg.checkpointing) == FileCheckpointConfig:
        checkpoint_handler = FileCheckpointHandler(cfg=cfg.checkpointing)
    elif OmegaConf.get_type(cfg.checkpointing) == GCSCheckpointConfig:
        checkpoint_handler = GCSCheckpointHandler(cfg=cfg.checkpointing)
    else:
        raise ValueError("Checkpoint Config has not been specified correctly.")

    if main_proc and cfg.visualization.visdom:
        visdom_logger = VisdomLogger(id=cfg.visualization.id,
                                     num_epochs=cfg.training.epochs)
    if main_proc and cfg.visualization.tensorboard:
        tensorboard_logger = TensorBoardLogger(
            id=cfg.visualization.id,
            log_dir=to_absolute_path(cfg.visualization.log_dir),
            log_params=cfg.visualization.log_params)

    if cfg.checkpointing.load_auto_checkpoint:
        latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
        if latest_checkpoint:
            cfg.checkpointing.continue_from = latest_checkpoint

    if cfg.checkpointing.continue_from:  # Starting from previous model
        state = TrainingState.load_state(
            state_path=to_absolute_path(cfg.checkpointing.continue_from))
        model = state.model
        if cfg.training.finetune:
            state.init_finetune_states(cfg.training.epochs)

        if main_proc and cfg.visualization.visdom:  # Add previous scores to visdom graph
            visdom_logger.load_previous_values(state.epoch, state.results)
        if main_proc and cfg.visualization.tensorboard:  # Previous scores to tensorboard logs
            tensorboard_logger.load_previous_values(state.epoch, state.results)
    else:
        # Initialise new model training
        with open(to_absolute_path(cfg.data.labels_path)) as label_file:
            labels = json.load(label_file)

        if OmegaConf.get_type(cfg.model) is BiDirectionalConfig:
            model = DeepSpeech(
                rnn_hidden_size=cfg.model.hidden_size,
                nb_layers=cfg.model.hidden_layers,
                labels=labels,
                rnn_type=supported_rnns[cfg.model.rnn_type.value],
                audio_conf=cfg.data.spect,
                bidirectional=True)
        elif OmegaConf.get_type(cfg.model) is UniDirectionalConfig:
            model = DeepSpeech(
                rnn_hidden_size=cfg.model.hidden_size,
                nb_layers=cfg.model.hidden_layers,
                labels=labels,
                rnn_type=supported_rnns[cfg.model.rnn_type.value],
                audio_conf=cfg.data.spect,
                bidirectional=False,
                context=cfg.model.lookahead_context)
        else:
            raise ValueError("Model Config has not been specified correctly.")

        state = TrainingState(model=model)
        state.init_results_tracking(epochs=cfg.training.epochs)

    # Data setup
    evaluation_decoder = GreedyDecoder(
        model.labels)  # Decoder used for validation
    train_dataset = SpectrogramDataset(audio_conf=model.audio_conf,
                                       manifest_filepath=to_absolute_path(
                                           cfg.data.train_manifest),
                                       labels=model.labels,
                                       normalize=True,
                                       augmentation_conf=cfg.data.augmentation)
    test_dataset = SpectrogramDataset(audio_conf=model.audio_conf,
                                      manifest_filepath=to_absolute_path(
                                          cfg.data.val_manifest),
                                      labels=model.labels,
                                      normalize=True)
    if not is_distributed:
        train_sampler = DSRandomSampler(dataset=train_dataset,
                                        batch_size=cfg.data.batch_size,
                                        start_index=state.training_step)
    else:
        train_sampler = DSElasticDistributedSampler(
            dataset=train_dataset,
            batch_size=cfg.data.batch_size,
            start_index=state.training_step)
    train_loader = AudioDataLoader(dataset=train_dataset,
                                   num_workers=cfg.data.num_workers,
                                   batch_sampler=train_sampler)
    test_loader = AudioDataLoader(dataset=test_dataset,
                                  num_workers=cfg.data.num_workers,
                                  batch_size=cfg.data.batch_size)

    model = model.to(device)
    parameters = model.parameters()
    if OmegaConf.get_type(cfg.optim) is SGDConfig:
        optimizer = torch.optim.SGD(parameters,
                                    lr=cfg.optim.learning_rate,
                                    momentum=cfg.optim.momentum,
                                    nesterov=True,
                                    weight_decay=cfg.optim.weight_decay)
    elif OmegaConf.get_type(cfg.optim) is AdamConfig:
        optimizer = torch.optim.AdamW(parameters,
                                      lr=cfg.optim.learning_rate,
                                      betas=cfg.optim.betas,
                                      eps=cfg.optim.eps,
                                      weight_decay=cfg.optim.weight_decay)
    else:
        raise ValueError("Optimizer has not been specified correctly.")

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=not cfg.training.no_cuda,
                                      opt_level=cfg.apex.opt_level,
                                      loss_scale=cfg.apex.loss_scale)
    if state.optim_state is not None:
        optimizer.load_state_dict(state.optim_state)
    if state.amp_state is not None:
        amp.load_state_dict(state.amp_state)

    # Track states for optimizer/amp
    state.track_optim_state(optimizer)
    if not cfg.training.no_cuda:
        state.track_amp_state(amp)

    if is_distributed:
        model = DistributedDataParallel(model, device_ids=[device_id])
    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    criterion = CTCLoss()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(state.epoch, cfg.training.epochs):
        model.train()
        end = time.time()
        start_epoch_time = time.time()
        state.set_epoch(epoch=epoch)
        train_sampler.set_epoch(epoch=epoch)
        train_sampler.reset_training_step(training_step=state.training_step)
        for i, (data) in enumerate(train_loader, start=state.training_step):
            state.set_training_step(training_step=i)
            inputs, targets, input_percentages, target_sizes = data
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = inputs.to(device)

            out, output_sizes = model(inputs, input_sizes)
            out = out.transpose(0, 1)  # TxNxH

            float_out = out.float()  # ensure float32 for loss
            loss = criterion(float_out, targets, output_sizes,
                             target_sizes).to(device)
            # loss = loss / inputs.size(0)  # average the loss by minibatch
            loss_value = loss.item()

            # Check to ensure valid loss was calculated
            valid_loss, error = check_loss(loss, loss_value)
            if valid_loss:
                optimizer.zero_grad()

                # compute gradient
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               cfg.optim.max_norm)
                optimizer.step()
            else:
                print(error)
                print('Skipping grad update')
                loss_value = 0

            state.avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      (epoch + 1), (i + 1),
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))

            if main_proc and cfg.checkpointing.checkpoint_per_iteration:
                checkpoint_handler.save_iter_checkpoint_model(epoch=epoch,
                                                              i=i,
                                                              state=state)

            del loss, out, float_out

        state.avg_loss = state.avg_loss / len(
            train_dataset) * cfg.data.batch_size

        epoch_time = time.time() - start_epoch_time
        print('Training Summary Epoch: [{0}]\t'
              'Time taken (s): {epoch_time:.0f}\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1,
                                                 epoch_time=epoch_time,
                                                 loss=state.avg_loss))

        time.sleep(15 * 60)

        with torch.no_grad():
            wer, cer, output_data = run_evaluation(
                test_loader=test_loader,
                device=device,
                model=model,
                decoder=evaluation_decoder,
                target_decoder=evaluation_decoder)

            print('Validation Summary Epoch: [{0}]\t'
                  'Average WER {wer:.3f}\t'
                  'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer,
                                                   cer=cer))

        state.add_results(epoch=epoch,
                          loss_result=state.avg_loss,
                          wer_result=wer,
                          cer_result=cer)

        if main_proc and cfg.visualization.visdom:
            visdom_logger.update(epoch, state.result_state)
        if main_proc and cfg.visualization.tensorboard:
            tensorboard_logger.update(epoch, state.result_state,
                                      model.named_parameters())

        if main_proc and cfg.checkpointing.checkpoint:  # Save epoch checkpoint
            checkpoint_handler.save_checkpoint_model(epoch=epoch, state=state)
        # anneal lr
        for g in optimizer.param_groups:
            g['lr'] = g['lr'] / cfg.optim.learning_anneal
        print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))

        if main_proc and (state.best_wer is None or state.best_wer > wer):
            checkpoint_handler.save_best_model(epoch=epoch, state=state)
            state.set_best_wer(wer)
            state.reset_avg_loss()
        state.reset_training_step()  # Reset training step for next epoch
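Note: `AverageMeter` is used above for `batch_time`, `data_time` and `losses` but defined elsewhere. The sketch below assumes the conventional implementation with `val`, `avg`, `sum`, `count` and an `update(value, n)` method, which matches how it is used in the loop.

class AverageMeter:
    """Tracks the most recent value and the running average (sketch)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val          # last observed value (printed as {loss.val})
        self.sum += val * n     # weighted by batch size, as in losses.update(loss_value, inputs.size(0))
        self.count += n
        self.avg = self.sum / self.count  # running mean (printed as {loss.avg})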
Exemple #29
0
class DistTrainer():
    r"""
    分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。

    Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前,
    请仔细检查,确保训练代码中的同步和互斥操作能正确执行(如模型保持,打印日志等)
    """
    def __init__(self, train_data, model, optimizer=None, loss=None,
                 callbacks_all=None, callbacks_master=None,
                 batch_size_per_gpu=8, n_epochs=1,
                 num_workers=1, drop_last=False,
                 dev_data=None, metrics=None, metric_key=None,
                 update_every=1, print_every=10, validate_every=-1,
                 save_path=None, device='auto',
                 fp16=False, use_tqdm=True, **kwargs):
        r"""

        :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。
        :param nn.modules model: 待训练的模型
        :param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器
        :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward`
        :param list callbacks_all: 用于在train过程中起调节作用的回调函数,作用于所有训练进程中。
            可使用的callback参见 :mod:`callback模块 <fastNLP.core.callback>`
        :param list callbacks_master: 用于在train过程中起调节作用的回调函数,只作用于其中一个进程( Master 进程)。
            可使用的callback参见 :mod:`callback模块 <fastNLP.core.callback>`
        :param int batch_size_per_gpu: 训练时,每个进程的 batch 大小。
        :param int n_epochs: 需要优化迭代多少次。
        :param num_workers: int, 有多少个线程来进行数据pad处理。
        :param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch
        :param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。
        :param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,
            也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。
            如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None,
            则保存当前模型。Metric种类详见 :mod:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。
        :param str,None metric_key:  :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标,
            比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需
            要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表
            明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。
        :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128
            会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。
        :param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。
        :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。
        :param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存
            最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
        :param str device: 指定 device,可以是 gpu,cpu 或 auto
        :param bool fp16: 指定是否使用半精度训练。
        :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
        :param kwargs: 支持配置可选参数
            bool test_use_tqdm: 在dev上验证的时候是否开启tqdm
            Sampler test_sampler: 在evaluate的时候使用的sampler
            int dev_batch_size: 在evaluate时,使用的evaluate的batch大小
            bool test_use_fp16: test时使用fp16
            bool set_grad_to_none: zero_grad时将grad设为None而不是0
            GradScaler gradscaler: 自定义的梯度 scaler
        """
        assert device in ['auto', 'cuda', 'cpu'], "Please set a correct device in ['auto', 'cuda', 'cpu']"
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # init distributed
        if device == 'cuda':
            torch.cuda.set_device(get_local_rank())
            self.device = torch.device("cuda", get_local_rank())
        else:
            self.device = torch.device(device)

        init_logger_dist()

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank() # unique id for each process

        self.train_data = train_data
        self.batch_size_per_gpu = int(batch_size_per_gpu)
        self.n_epochs = int(n_epochs)
        self.num_data_workers = int(num_workers)
        self.drop_last = drop_last
        self.update_every = int(update_every)
        self.print_every = int(print_every)
        self.validate_every = int(validate_every)
        self.save_path = save_path
        self.losser = _prepare_losser(loss)
        self.fp16 = fp16
        self.local_rank = get_local_rank()
        self._forward_func = model.forward
        self.callback_manager = DistCallbackManager(
            env={"trainer": self}, callbacks_all=callbacks_all,
            callbacks_master=callbacks_master)
        self.test_manager = DistCallbackManager(env={'trainer': self})
        self.metric_key = metric_key
        self.use_tqdm = use_tqdm

        model.to(self.device)

        # init fp16, must before DataParallel init
        autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
        self.auto_cast = autocast
        user_grad_scaler = kwargs.get('gradscaler', None)
        if user_grad_scaler is not None:
            assert self.fp16, "must set fp16=True to enable gradscaler"
            grad_scaler = user_grad_scaler
        else:
            grad_scaler = GradScaler()
        self.grad_scaler = grad_scaler

        self.set_grad_to_none = kwargs.get('set_grad_to_none', True)

        # init DataParallel
        if parse_version(torch.__version__) >= parse_version('1.1'):
            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                 output_device=self.local_rank, find_unused_parameters=True)
        else:
            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                 output_device=self.local_rank)
        self.model = self.ddp_model.module

        optimizer = self._get_optimizer(optimizer)
        self.optimizer = optimizer
        if isinstance(self.train_data, DataSet):
            self.sampler = DistributedSampler(self.train_data)
        self.data_iterator = self._get_data_iter(self.train_data)
        self.batch_size = self.world_size * self.batch_size_per_gpu
        self.n_steps = self._get_n_steps()

        self.dev_data = dev_data
        self.metrics = metrics
        self.test_use_tqdm = True
        self.kwargs = kwargs
        self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
        dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)

        # for evaluation, only run eval on master proc
        if dev_data and metrics:
            cb = _TesterCallback(
                dev_data, model, metrics,
                batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                use_tqdm=self.test_use_tqdm)
            self.test_manager.add_callback([cb], master=True)

        # Setup logging
        # synchronize start_time across processes
        sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
        dist.broadcast(sync_time, src=0)
        self.start_time = datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')
        # print('sync_time: {}, start_time: {}'.format(sync_time, self.start_time))

        if self.save_path:
            self.cp_save_path = self.save_path
        else:
            self.cp_save_path = None
        # use INFO in the master, WARN for others
        self.logger = logger
        self.logger.info("Setup Distributed Trainer")
        self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
                        os.getpid(), self.rank, self.local_rank, self.device, self.fp16))
        self.logger.info("Num of processes: {}".format(self.world_size))
        self.logger.info("Use device: {}".format(device))

    def _maybe_no_sync(self):
        """
        Whenever *samples* contains more than one mini-batch, we
        want to accumulate gradients locally and only call
        all-reduce in the last backwards pass.
        """
        i = self.step % self.update_every
        if (
                self.world_size > 1
                and hasattr(self.ddp_model, "no_sync")
                and i != 0
        ):
            return self.ddp_model.no_sync()
        else:
            return contextlib.ExitStack()  # dummy contextmanager

    def _get_n_steps(self):
        return len(self.data_iterator) * self.n_epochs

    def _get_data_iter(self, dataset):
        if isinstance(dataset, DataSet):
            return DataSetIter(dataset=dataset, batch_size=self.batch_size_per_gpu, sampler=self.sampler,
                               num_workers=self.num_data_workers, drop_last=self.drop_last)
        elif isinstance(dataset, BatchIter):
            return dataset
        else:
            raise TypeError("train_data type {} not support".format(type(dataset)))

    def _get_optimizer(self, optimizer):
        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer
        elif isinstance(optimizer, Optimizer):
            return optimizer.construct_from_pytorch(self.ddp_model.parameters())
        elif optimizer is None:
            return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
        else:
            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
                raise TypeError("optimizer must have a callable step() function.")
            return optimizer

    @property
    def is_master(self):
        r"""是否是主进程"""
        return self.rank == 0

    def train(self, load_best_model=True, on_exception='auto'):
        r"""
        使用该函数使Trainer开始训练。

        :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。
                支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出;
                'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception.
        :return dict: 返回一个字典类型的数据,
                内含以下内容::

                    seconds: float, 表示训练时长
                    以下三个内容只有在提供了dev_data的情况下会有。
                    best_eval: Dict of Dict, 表示evaluation的结果。第一层的key为Metric的名称,
                                第二层的key为具体的Metric
                    best_epoch: int,在第几个epoch取得的最佳值
                    best_step: int, 在第几个step(batch)更新取得的最佳值

        """
        try:
            self.logger.info("###### Training epochs started ######")
            self.logger.info('Total epochs: %d'% self.n_epochs)
            self.logger.info('Total steps: %d'% self.n_steps)
            self.logger.info('Num instances per GPU: %d'% self.batch_size_per_gpu)
            self.logger.info('Num of steps per update: %d' % self.update_every)
            self.logger.info('Total batch_size: %d'%
                             (self.batch_size_per_gpu * dist.get_world_size() * self.update_every))
            self.logger.info('Total num of samples: %d'% len(self.train_data))
            self.logger.info("Num of callbacks for all workers: {}".format(
                                len(self.callback_manager.callbacks_all)))
            self.logger.info("Num of callbacks for master workers: {}".format(
                                len(self.callback_manager.callbacks_master)))
            self.logger.info("Callbacks for all workers: {}".format(
                    [repr(cb) for cb in self.callback_manager.callbacks_all]))
            self.logger.info("Callbacks for master workers: {}".format(
                    [repr(cb) for cb in self.callback_manager.callbacks_master]))

            start_time = time.time()
            results = {}
            if self.n_epochs <= 0:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
                results['seconds'] = 0.
                return results

            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end()

            except BaseException as e:
                self.callback_manager.on_exception(e)
                if on_exception == 'auto':
                    if not isinstance(e, (CallbackException, KeyboardInterrupt)):
                        raise e
                    else:
                        self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
                elif on_exception == 'raise':
                    raise e

            results['seconds'] = round(time.time() - start_time, 2)
            self.logger.info("###### Train finished ######")
            self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
            if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
                self.load_check_point(self._best_save_name())
        finally:
            pass
        dist.barrier()
        return results

    def _train(self):
        dist.barrier()
        if not self.use_tqdm:
            from .utils import _pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm

        self.step = 0
        self.epoch = 0
        self.pbar = inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
                        leave=False, dynamic_ncols=True, disable=not self.is_master)
        pbar = self.pbar
        avg_loss = 0
        data_iterator = self.data_iterator
        self.ddp_model.zero_grad()
        for epoch in range(1, self.n_epochs + 1):
            self.epoch = epoch
            pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
            # early stopping
            self.callback_manager.on_epoch_begin()
            for batch_x, batch_y in data_iterator:
                self.step += 1
                self.ddp_model.train()
                _move_dict_value_to_device(batch_x, batch_y, device=self.device)
                indices = data_iterator.get_batch_indices()
                # negative sampling; replace unknown; re-weight batch_y
                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                with self.auto_cast():
                    prediction = self._data_forward(self.ddp_model, batch_x)
                    # edit prediction
                    self.callback_manager.on_loss_begin(batch_y, prediction)
                    loss = self._compute_loss(prediction, batch_y)

                avg_loss += loss.detach()

                # Is loss NaN or inf? requires_grad = False
                self.callback_manager.on_backward_begin(loss)
                self.grad_scaler.scale(loss).backward()
                self.callback_manager.on_backward_end()
                if self.step % self.update_every == 0:
                    self._update()
                self.callback_manager.on_step_end()

                if self.step % self.print_every == 0:
                    avg_loss = float(avg_loss) / self.print_every
                    print_output = "loss:{:<6.5f}".format(avg_loss)
                    pbar.update(self.print_every)
                    pbar.set_postfix_str(print_output)
                    avg_loss = 0

                self.callback_manager.on_batch_end()

                if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks):
                    self._do_validation()

            # ================= mini-batch end ==================== #
            if self.validate_every < 0 and len(self.test_manager.callbacks):
                self._do_validation()

            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
        # =============== epochs end =================== #
        pbar.close()
        self.pbar = None
    # ============ tqdm end ============== #

    def _clear_grad_opt(self, optimizer):
        if self.set_grad_to_none:
            for group in optimizer.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        p.grad = None
        else:
            optimizer.zero_grad()

    def _update(self):
        r"""Perform weight update on a model.

        """
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        self._clear_grad_opt(self.optimizer)

    def _data_forward(self, network, x):
        x = _build_args(self._forward_func, **x)
        y = network(**x)
        if not isinstance(y, dict):
            raise TypeError(
                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
        return y

    def _compute_loss(self, predict, truth):
        r"""Compute loss given prediction and ground truth.

        :param predict: prediction dict, produced by model.forward
        :param truth: ground truth dict, produced by batch_y
        :return: a scalar
        """
        loss = self.losser(predict, truth)
        if self.update_every > 1:
            loss = loss / self.update_every
        if loss.dim() > 0:
            loss = loss.mean()
        return loss

    def save_check_point(self, name=None, only_params=False):
        r"""保存当前模型"""
        # only master save models
        if name is None:
            name = 'checkpoint-{}.bin'.format(self.step)
        os.makedirs(self.cp_save_path, exist_ok=True)
        path = os.path.join(self.cp_save_path, name)
        self.logger.info("Save checkpoint to {}".format(path))
        model_to_save = self.ddp_model.module
        if only_params:
            model_to_save = model_to_save.state_dict()
        if self.is_master:
            torch.save(model_to_save, path)

    def load_check_point(self, name):
        path = os.path.join(self.cp_save_path, name)
        self.logger.info('reload best model from %s', path)
        model_load = torch.load(
            path,
            map_location=lambda s, l: default_restore_location(s, "cpu"))
        if not isinstance(model_load, dict):
            model_load = model_load.state_dict()
        self.model.load_state_dict(model_load)

    def _best_save_name(self, auto_fix=True):
        best_name = "best_" + "_".join([self.model.__class__.__name__, str(self.metric_key), self.start_time])
        return best_name

    def _do_validation(self):
        with self.ddp_model.no_sync():
            # since the model parameters are not updated here, gradient synchronization can be disabled
            self.callback_manager.on_valid_begin()
            eval_res = self.test_manager.on_valid_begin()
            eval_res = list(filter(lambda x: x is not None, eval_res))
            if len(eval_res):
                eval_res, is_better = list(zip(*eval_res))
                eval_res = eval_res[0]
                is_better = is_better[0]
            else:
                eval_res, is_better = None, None
            if self.metric_key is None and eval_res is not None:
                eval_res0 = list(eval_res.values())[0]
                self.metric_key = list(eval_res0.keys())[0]
            # logger.info('{}, {}'.format(eval_res, is_better))
            # save better model on master node
            if is_better is not None and self.cp_save_path:
                if is_better:
                    self.save_check_point(self._best_save_name(), only_params=False)
            dist.barrier()

            if not self.is_master and self.metric_key is None:
                # the master process has determined metric_key automatically; the other processes have not
                prefix = 'best_' + self.model.__class__.__name__
                suffix = self.start_time
                fn_list = os.listdir(self.cp_save_path)
                fn_list = [fn for fn in fn_list if fn.startswith(prefix) and fn.endswith(suffix)]
                if len(fn_list) == 1:
                    best_name = fn_list[0]
                    self.metric_key = best_name[len(prefix):-len(suffix)].strip('_')
            # print('RANK {} metric_key {}'.format(self.rank, self.metric_key))
            self.callback_manager.on_valid_end(
                eval_res, self.metric_key, self.optimizer, is_better)
            self.ddp_model.train()

    def close(self):
        r"""关闭Trainer,销毁进程"""
        dist.destroy_process_group()
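
# --- Usage sketch (illustrative only): the class above appears to follow a
# fastNLP-style DistTrainer interface, but the class name and constructor
# arguments below are assumptions, not taken from this file. ---
#
#   trainer = DistTrainer(train_data=train_set, model=model, loss=loss,
#                         metrics=metric, validate_every=-1,
#                         save_path="checkpoints/")
#   trainer.train()                                       # runs the epoch/batch loop above
#   trainer.load_check_point(trainer._best_save_name())   # reload the best checkpoint
#   trainer.close()                                       # destroy the process group
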
class DeepSpeedPlugin(DDPPlugin):
    distributed_backend = "deepspeed"
    DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"

    def __init__(self,
                 zero_optimization: bool = True,
                 stage: int = 2,
                 cpu_offload: bool = False,
                 contiguous_gradients: bool = True,
                 overlap_comm: bool = True,
                 allgather_partitions: bool = True,
                 reduce_scatter: bool = True,
                 allgather_bucket_size: int = 2e8,
                 reduce_bucket_size: int = 2e8,
                 zero_allow_untested_optimizer: bool = True,
                 config: Optional[Union[Path, str, dict]] = None,
                 logging_level: int = logging.WARN,
                 num_nodes: int = 1,
                 parallel_devices: Optional[List[torch.device]] = None,
                 cluster_environment: Optional[ClusterEnvironment] = None,
                 loss_scale: float = 0,
                 initial_scale_power: int = 32,
                 loss_scale_window: int = 1000,
                 hysteresis: int = 2,
                 min_loss_scale: int = 1) -> None:
        """

        Provides capabilities to run training using the DeepSpeed library,
        with training optimizations for large, billion-parameter models.
        `For more information: https://www.deepspeed.ai/`.

        .. warning:: ``DeepSpeedPlugin`` is in beta and subject to change.

        Defaults have been set to enable ZeRO optimization (stage 2); some values have been taken from the link below.
        These defaults have been set generally, but may require tuning for optimum performance based on your model size.
        `For more information: https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training`.

        Arguments:

            zero_optimization: Enable ZeRO optimization. This is only compatible with precision=16. (default: True)

            stage: Different stages of the ZeRO Optimizer. 0 is disabled,
                1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning (default: 2)

            cpu_offload: Enable offloading optimizer memory and computation to CPU (default: False)

            contiguous_gradients: Copies gradients to a contiguous buffer as they are produced.
                Avoids memory fragmentation during the backward pass. Useful when training large models. (default: True)

            overlap_comm: Overlap the reduction (synchronization) of gradients with the backwards computation.
                This is a speed optimization when training across multiple GPUs/machines. (default: True)

            allgather_partitions: All gather updated parameters at the end of training step,
                instead of using a series of broadcast collectives (default: True)

            reduce_scatter: Use reduce/scatter instead of allreduce to average gradients (default: True)

            allgather_bucket_size: Number of elements to allgather at once.
                Used to limit the memory required for larger model sizes, with a tradeoff with speed. (default: 2e8)

            reduce_bucket_size: Number of elements to reduce at once.
                Used to limit the memory required for larger model sizes, with a tradeoff with speed (default: 2e8)

            zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
                DeepSpeed supported optimizer when using ZeRO (default: True)

            config: Pass in a deepspeed formatted config dict,
                or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json.
                All defaults will be ignored if a config is passed in. (Default: ``None``)

            logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``)

            loss_scale: Loss scaling value for FP16 training.
                0.0 results in dynamic loss scaling, otherwise static (Default: 0)

            initial_scale_power: Power of the initial dynamic loss scale value. Loss scale is computed
                by ``2^initial_scale_power`` (Default: 32)

            loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value (Default: 1000)

            hysteresis: Delay shift for dynamic FP16 loss scaling (Default: 2)

            min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1)

        """
        if not _DEEPSPEED_AVAILABLE:
            raise MisconfigurationException(
                "To use the DeepSpeed plugin, you must have DeepSpeed installed."
                " pip install deepspeed mpi4py")
        super().__init__(parallel_devices=parallel_devices,
                         num_nodes=num_nodes,
                         cluster_environment=cluster_environment)
        self.config = self._load_config(config)
        if self.config is None:
            # User has not overridden config, set defaults
            self.config = self._create_default_config(
                zero_optimization,
                zero_allow_untested_optimizer,
                stage=stage,
                cpu_offload=cpu_offload,
                contiguous_gradients=contiguous_gradients,
                overlap_comm=overlap_comm,
                allgather_partitions=allgather_partitions,
                reduce_scatter=reduce_scatter,
                allgather_bucket_size=allgather_bucket_size,
                reduce_bucket_size=reduce_bucket_size)
        self._config_initialized = False
        deepspeed.utils.logging.logger.setLevel(logging_level)

        # default FP16 parameters.
        self.loss_scale = loss_scale
        self.initial_scale_power = initial_scale_power
        self.loss_scale_window = loss_scale_window
        self.hysteresis = hysteresis
        self.min_loss_scale = min_loss_scale

    def _load_config(self, config):
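        # Example (illustrative): a config file can also be supplied through the
        # environment variable checked below, e.g.
        #   PL_DEEPSPEED_CONFIG_PATH=/path/to/ds_config.json python train.py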
        if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
            rank_zero_info(
                f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable"
            )
            config = os.environ[self.DEEPSPEED_ENV_VAR]
        if isinstance(config, (str, Path)):
            if not os.path.isfile(config):
                raise MisconfigurationException(
                    f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
                )
            with open(config) as f:
                config = json.load(f)
        return config

    def pre_dispatch(self):
        self.set_world_ranks()
        self.init_ddp_connection(self.global_rank, self.world_size)

        self.init_deepspeed()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device
        self.barrier()

    def init_deepspeed(self):
        if not self._config_initialized:
            self._format_config()
            self._config_initialized = True

        precision = self.lightning_module.trainer.accelerator.precision
        model = LightningDeepSpeedModule(pl_module=self.model,
                                         precision=precision)

        if self.lightning_module.trainer and self.lightning_module.trainer.training:
            self._initialize_deepspeed_train(model)
        else:
            self._initialize_deepspeed_inference(model)

    def _init_scheduler_optimizer(self):
        optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
            self.lightning_module)
        if len(optimizers) > 1 or len(schedulers) > 1:
            raise MisconfigurationException(
                "DeepSpeed currently only supports single optimizer, single optional scheduler."
            )
        scheduler = schedulers[0]['scheduler'] if len(
            schedulers) == 1 else None
        optimizer = optimizers[0]
        return optimizer, scheduler, optimizer_frequencies

    def _initialize_deepspeed_train(self, model):
        optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
        if "optimizer" not in self.config:
            rank_zero_info(
                "You have not specified an optimizer or scheduler within the DeepSpeed config."
                "Using `configure_optimizers` to define optimizer and scheduler."
            )
            optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer(
            )
        model_parameters = filter(lambda p: p.requires_grad,
                                  self.model.parameters())
        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            args=SimpleNamespace(local_rank=self.local_rank),
            model=model,
            model_parameters=model_parameters,
            optimizer=optimizer,
            lr_scheduler=lightning_scheduler,
            config_params=self.config,
        )
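        # deepspeed.initialize() returns (engine, optimizer, training_dataloader, lr_scheduler);
        # the dataloader slot is unused here because Lightning manages data loading itself.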

        # set optimizer for save/load, but deepspeed manages the specific optimizer logic
        self.lightning_module.trainer.optimizers = [optimizer]
        self.model = model

    def _initialize_deepspeed_inference(self, model):
        # move the model to the correct device
        self.model_to_device()

        self.pre_configure_ddp()
        self.model = DistributedDataParallel(
            model,
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def configure_scheduler(self, lr_scheduler):
        scheduler = _get_default_scheduler_config()
        scheduler["scheduler"] = lr_scheduler
        return [scheduler]

    @property
    def lightning_module(self):
        # the model may not be wrapped with the DeepSpeedEngine & LightningDeepSpeedModule if this is called too early
        module = getattr(self.model, "module", self.model)
        return module.module if isinstance(
            module, LightningDeepSpeedModule) else module

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.world_size,
                                          rank=self.global_rank)
        return distributed_sampler_kwargs

    def init_optimizers(self, trainer,
                        model: LightningModule) -> Tuple[List, List, List]:
        # Skip initializing optimizers here, as DeepSpeed handles optimizers via its config.
        # The user may instead have defined optimizers in `configure_optimizers`; that case
        # is handled in `_initialize_deepspeed_train`.
        return [], [], []  # empty optimizers, schedulers and frequencies

    def optimizer_step(self, optimizer: torch.optim.Optimizer,
                       lambda_closure: Callable, **kwargs):
        # Note: we rely on the DeepSpeed engine to carry out the step rather than the optimizer;
        # internally, the engine already holds a reference to the optimizer.
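        # self.model here is the DeepSpeed engine returned by deepspeed.initialize(),
        # and engine.step() performs the optimizer step, updates the loss scale and
        # advances any scheduler that was passed to deepspeed.initialize().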
        self.model.step(**kwargs)

    def _format_config(self):
        if self.config is None:
            raise MisconfigurationException(
                "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
                " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed"
            )
        self._format_batch_size_and_grad_accum_config()
        self._format_precision_config()

    def _format_batch_size_and_grad_accum_config(self):
        if "gradient_accumulation_steps" in self.config:
            raise MisconfigurationException(
                "Within the DeepSpeed config, do not set gradient_accumulation_steps"
                " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
            )
        if "train_micro_batch_size_per_gpu" not in self.config:
            # train_micro_batch_size_per_gpu is used by DeepSpeed for throughput logging;
            # by default we take the batch size of the train dataloader, which may be
            # incorrect if a batch sampler is passed.
            batch_size = self.lightning_module.train_dataloader().batch_size
            self.config["train_micro_batch_size_per_gpu"] = batch_size
        self.config[
            "gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
        if "gradient_clipping" not in self.config:
            self.config[
                "gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val

    def _format_precision_config(self):

        amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
        amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
        precision = self.lightning_module.trainer.accelerator_connector.precision
        if precision == 16:
            if "fp16" not in self.config and amp_type == AMPType.NATIVE:
                # FP16 is a DeepSpeed standalone AMP implementation
                rank_zero_info("Enabling DeepSpeed FP16.")
                self.config["fp16"] = {
                    "enabled": True,
                    "loss_scale": self.loss_scale,
                    "initial_scale_power": self.initial_scale_power,
                    "loss_scale_window": self.loss_scale_window,
                    "hysteresis": self.hysteresis,
                    "min_loss_scale": self.min_loss_scale
                }
            elif "amp" not in self.config and amp_type == AMPType.APEX:
                rank_zero_only("Enabling DeepSpeed APEX Implementation.")
                self.config["amp"] = {
                    "enabled": True,
                    "opt_level": amp_level,
                }
        if "zero_optimization" in self.config and not ("amp" in self.config or
                                                       "fp16" in self.config):
            raise MisconfigurationException(
                "To use DeepSpeed ZeRO Optimization, you must set precision=16."
            )

    def _create_default_config(self, zero_optimization: bool,
                               zero_allow_untested_optimizer: bool,
                               **zero_kwargs) -> Dict:
        if zero_optimization:
            return {
                "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
                "zero_optimization": zero_kwargs
            }
        return {}
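
# --- Usage sketch (illustrative only; the Trainer arguments are assumed from the
# PyTorch Lightning API of this era, not taken from this file) ---
#
#   import pytorch_lightning as pl
#   trainer = pl.Trainer(gpus=4, precision=16,
#                        plugins=[DeepSpeedPlugin(stage=2, cpu_offload=True)])
#   trainer.fit(model)
#
# With the constructor defaults, _create_default_config(...) yields roughly:
#
#   {
#       "zero_allow_untested_optimizer": True,
#       "zero_optimization": {
#           "stage": 2,
#           "cpu_offload": False,
#           "contiguous_gradients": True,
#           "overlap_comm": True,
#           "allgather_partitions": True,
#           "reduce_scatter": True,
#           "allgather_bucket_size": 2e8,
#           "reduce_bucket_size": 2e8,
#       },
#   }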