Example #1
 def __exit__(self, type, value, traceback):
     if getattr(self, 'scope', None):
         # In ir.cpp ResetScopeContext we ensure that we have no remaining scope
         # before marking step.
         del self.scope
     xm.mark_step()
     super().__exit__(type, value, traceback)
Example #2
 def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
   retval = None
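   # mark_step() flushes the pending lazy graph so the found_inf flags can be read on the host below.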
   xm.mark_step()
   if not sum(
       v.item() for v in optimizer_state["found_inf_per_device"].values()):
     retval = optimizer.step(*args, **kwargs)
   return retval
Example #3
def run_single(args, m, n, k):

    dtype = args.dtype
    device = args.device
    warmups = args.warmups
    steps = args.steps

    dt = torch.float32
    if dtype in ("float16", "half"):
        dt = torch.float16
    elif dtype == "bfloat16":
        dt = torch.bfloat16

    torch.manual_seed(0)

    elap = 0.0

    a = torch.randn(m, k).to(dt)
    b = torch.randn(k, n).to(dt)
    c = torch.zeros(m, n).to(dt)

    if device == 'cpu':

        measure_cpu(a, b, warmups)
        elap = measure_cpu(a, b, steps)

    elif device == 'gpu':

        if torch.cuda.is_available():
            # ncuda = torch.cuda.device_count()
            # print("There are {} cuda devices".format(ncuda))
            # print("The first cuda device name is {} ".format(torch.cuda.get_device_name()))
            cuda0 = torch.device('cuda:0')
            with torch.cuda.device(cuda0):
                acuda = a.to(cuda0)
                bcuda = b.to(cuda0)
                measure_gpu(acuda, bcuda, warmups)
                elap = measure_gpu(acuda, bcuda, steps)
        else:
            print("CUDA is not available")
            sys.exit(1)

    else:
        # import torch_xla
        import torch_xla.core.xla_model as xm

        # alldev = xm.get_xla_supported_devices()
        # allrealdev = xm.xla_real_devices(alldev)
        # print("Found {0} XLA devices: {1}".format(len(allrealdev), allrealdev))

        dev = xm.xla_device()
        a = a.to(dev)
        b = b.to(dev)
        c = c.to(dev)
        measure_xla(a, b, warmups)
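        # Flush the lazily traced warmup work to the device before timing the measured steps.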
        xm.mark_step()
        elap = measure_xla(a, b, steps)
        xm.mark_step()

    return elap
Example #4
def train_tpu(
    model,
    device,
    optimizer,
    data_type,
    input_size,
    output_size,
    batch_size,
    args,
):
    import torch_xla.core.xla_model as xm
    loss_f = nn.CrossEntropyLoss().to(device)

    # model.train()
    start_time = time.time()

    for i in range(args.steps + args.warmups):
        data = torch.randn(batch_size, input_size, device=device)
        target = torch.randint(output_size, [batch_size],
                               device=device,
                               dtype=torch.long)
        # data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data).float()
        loss = loss_f(output, target)
        loss.backward()
        optimizer.step()
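        # mark_step() cuts the lazy trace and launches compilation/execution of this training step on the XLA device.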
        xm.mark_step()
        if i < args.warmups:
            start_time = time.time()

    return time.time() - start_time, loss
Example #5
    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)

            scaler.scale(loss).backward()
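            # All-reduce the (still scaled) gradients across replicas; scaler.step() then unscales and skips the update if inf/NaN is found.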
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()

            import resource
            print(f" CPU Usage After: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")

            if step % FLAGS.log_steps == 0:
                # _train_update(device, step, loss, tracker, epoch, writer)
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer)
                )
Example #6
def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError(
            'Cannot initialize distributed with distributed_world_size=1')

    if not getattr(args, 'tpu', False):
        if torch.distributed.is_initialized():
            warnings.warn(
                'Distributed is already initialized, cannot initialize twice!')
        else:
            logger.info('distributed init (rank {}): {}'.format(
                args.distributed_rank,
                args.distributed_init_method,
            ))
            dist.init_process_group(
                backend=args.distributed_backend,
                init_method=args.distributed_init_method,
                world_size=args.distributed_world_size,
                rank=args.distributed_rank,
            )
            logger.info('initialized host {} as rank {}'.format(
                socket.gethostname(),
                args.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        args.distributed_rank = torch.distributed.get_rank()
    else:
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == args.distributed_world_size
        args.device_id = xm.get_local_ordinal()
        args.distributed_rank = xm.get_ordinal()
        xm.rendezvous('distributed_init')  # wait for all workers
        xm.mark_step()

    if is_master(args):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if args.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n  git submodule update --init '
                              'fairseq/model_parallel/megatron')
        initialize_model_parallel(args.model_parallel_size)
        model_parallel_cuda_manual_seed(args.seed)
        model_part_number = get_model_parallel_rank()
        args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number)
    return args.distributed_rank
Example #7
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        if self.tpu:
            import torch_xla.core.xla_model as xm

            xm.rendezvous("valid_step")  # wait for all workers
            xm.mark_step()

        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                if self._dummy_batch == "DUMMY":
                    self._dummy_batch = sample
                is_dummy_batch = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        logger.warning(
                            "ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.valid_step(sample, raise_oom=True)
                raise e

            logging_outputs = [logging_output]
            if is_dummy_batch:
                if torch.is_tensor(sample_size):
                    sample_size.zero_()
                else:
                    sample_size *= 0.0

        # gather logging outputs from all replicas
        if self.data_parallel_world_size > 1:
            logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ignore=is_dummy_batch,
            )

        # log validation stats
        logging_output = self._reduce_and_log_stats(logging_outputs,
                                                    sample_size)

        return logging_output
Example #8
 def __iter__(self):
     if self.rng_types is not None:
         synchronize_rng_states(self.rng_types, self.generator)
     state = AcceleratorState()
     for batch in super().__iter__():
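         # On TPU, mark_step() dispatches the lazily traced graph for the batch before it is handed to the caller.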
         if state.distributed_type == DistributedType.TPU:
             xm.mark_step()
         yield batch if self.device is None else send_to_device(
             batch, self.device)
Example #9
    def _test_optimizer(self,
                        syncfree_optim_cls,
                        ref_optim_cls,
                        optim_kwargs={'lr': 1e-2}):
        device = xm.xla_device()
        loss_fn = nn.NLLLoss()
        # syncfree model
        torch.manual_seed(0)
        syncfree_model = MNIST().train().to(device)
        syncfree_optimizer = syncfree_optim_cls(syncfree_model.parameters(),
                                                **optim_kwargs)
        # reference model
        torch.manual_seed(0)
        ref_model = MNIST().train().to(device)
        ref_optimizer = ref_optim_cls(ref_model.parameters(), **optim_kwargs)
        # fake data
        data = torch.rand(32, 1, 28, 28).to(device)
        target = torch.zeros(32, dtype=torch.long).to(device)  # NLLLoss expects integer class indices
        # training loop
        for i in range(10):
            # syncfree step
            syncfree_optimizer.zero_grad()
            syncfree_output = syncfree_model(data)
            syncfree_loss = loss_fn(syncfree_output, target)
            syncfree_loss.backward()
            # mimic NaN in the gradients
            if i % 2 == 0:
                xm._fetch_gradients(syncfree_optimizer)[0].mul_(torch.nan)
                found_inf = torch.tensor(1.0).to(device)
            else:
                found_inf = torch.tensor(0.0).to(device)
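            # found_inf is passed to the sync-free optimizer, which skips the parameter update on-device when it is nonzero.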
            xm.optimizer_step(syncfree_optimizer,
                              optimizer_args={"found_inf": found_inf})
            xm.mark_step()
            # reference step
            ref_optimizer.zero_grad()
            ref_output = ref_model(data)
            ref_loss = loss_fn(ref_output, target)
            ref_loss.backward()
            # mimic the effect of the found_inf tensor
            if i % 2 != 0:
                xm.optimizer_step(ref_optimizer)
            xm.mark_step()
            # check loss
            np.testing.assert_allclose(ref_loss.cpu().detach().numpy(),
                                       syncfree_loss.cpu().detach().numpy(),
                                       rtol=1e-2,
                                       atol=1e-2)

        # check weight
        for p, p_ref in zip(syncfree_model.parameters(),
                            ref_model.parameters()):
            np.testing.assert_allclose(p.cpu().detach().numpy(),
                                       p_ref.cpu().detach().numpy(),
                                       rtol=1e-2,
                                       atol=1e-2)
Example #10
 def next(self):
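   # Issue a mark_step() once every `_mark_step_batch_count` yielded batches so the traced graph stays bounded.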
   if self._mark_step_batch_count == self._batches_yielded:
     self._batches_yielded = 0
     xm.mark_step()
   else:
     self._batches_yielded += 1
   item = self._loader.next_item(self._device)
   if item is None:
     raise StopIteration
   return item
Example #11
 def step(self):
     """
         Takes optimizer step
     """
     self.optimizer.param_groups[0]['lr'] = self.get_lr()
     if self.device.type == 'xla':
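         # On XLA devices, xm.optimizer_step() reduces gradients across replicas before applying the update.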
         xm.optimizer_step(self.optimizer)
         xm.mark_step()
     else:
         self.optimizer.step()
Example #12
 def test_allgather(self):
   device = xm.xla_device()
   tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
   pg_xla = get_process_group_xla(rank=3, size=8)
   output_tensors = [torch.zeros_like(tensor)] * 8
   all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\('
   pg_xla.allgather([output_tensors], [tensor])
   hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
   hlo_matches(hlo, all_gather_pattern)
   # purge all computations attached to the device.
   xm.mark_step()
Example #13
def xla_run_method(rank, fit_method, learner_args, add_args, fit_args,
                   ctrl_args):
    "run fit method on spawned process"
    sync_valid = True
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args,
                                     ctrl_args)
    fit_args = setup_fit_cbs(rank, fit_args)
    fit_method(learner, **fit_args)
    xm.rendezvous('xla_run_method')
    learner.save('_xla_tmp_model', rendezvous=False)
    xm.mark_step()
Example #14
def tpu_data_loader(args, itr):
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    xm.rendezvous('tpu_data_loader')  # wait for all workers
    xm.mark_step()
    device = utils.get_tpu_device(args)
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=getattr(itr, 'n', 0),
        total=len(itr),
    )
Example #15
    def train_loop_fn(loader, epoch):
        if FLAGS.fine_grained_metrics:
            epoch_start_time = time.time()
            step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
        else:
            tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if FLAGS.fine_grained_metrics:
                step_start_time = time.time()
            optimizer.zero_grad()
            if FLAGS.fine_grained_metrics:
                fwd_start_time = time.time()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)
            if FLAGS.fine_grained_metrics:
                fwd_end_time = time.time()
                fwd_latency = fwd_end_time - fwd_start_time

                bwd_start_time = time.time()
            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            if lr_scheduler:
                lr_scheduler.step()
            if FLAGS.fine_grained_metrics:
                bwd_end_time = time.time()
                bwd_latency = bwd_end_time - bwd_start_time

                step_latency = bwd_end_time - step_start_time
                step_latency_tracker.append(step_latency)
                bwd_latency_tracker.append(bwd_latency)
                fwd_latency_tracker.append(fwd_latency)
            else:
                tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                if FLAGS.fine_grained_metrics:
                    print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
                        epoch, step, FLAGS.batch_size / p50(step_latency_tracker), FLAGS.batch_size,
                        p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
                else:
                    # _train_update(device, step, loss, tracker, epoch, writer)
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              epoch, writer))
        if FLAGS.fine_grained_metrics:
            epoch_end_time = time.time()
            epoch_latency = epoch_end_time - epoch_start_time
            print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
                epoch, epoch_latency, FLAGS.batch_size / p50(step_latency_tracker), FLAGS.batch_size,
                p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
Example #16
 def test_allreduce(self):
   device = xm.xla_device()
   tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
   pg_xla = get_process_group_xla(rank=511, size=1024)
   opts = dist.AllreduceOptions()
   opts.reduceOp = dist.ReduceOp.SUM
   all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
   with xm_cc_op_intercepted('all_reduce'):
     pg_xla.allreduce([tensor], opts)
   hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
   hlo_matches(hlo, all_reduce_pattern)
   # purge all computations attached to the device.
   xm.mark_step()
Example #17
def tpu_data_loader(itr):
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()
    device = xm.xla_device()
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=getattr(itr, "n", 0),
        total=len(itr),
    )
Example #18
 def test_reduce_scatter(self):
   device = xm.xla_device()
   tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
   input_list = [tensor]
   output = torch.zeros_like(tensor)
   pg_xla = get_process_group_xla(rank=0, size=len(input_list))
   opts = dist.ReduceScatterOptions()
   opts.reduceOp = dist.ReduceOp.SUM
   reduce_scatter_pattern = r'%reduce\-scatter\.\d+ = .+ reduce\-scatter\('
   pg_xla.reduce_scatter([output], [input_list], opts)
   hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
   hlo_matches(hlo, reduce_scatter_pattern)
   # purge all computations attached to the device.
   xm.mark_step()
Example #19
    def test_asynchronous(self):
        flag = Event()
        assert not flag.is_set()

        def closure():
            sleep(1)
            assert flag.is_set()

        xm.add_step_closure(closure, run_async=True)
        xm.mark_step()

        # should get to this part and complete before closure is finished running
        assert not flag.is_set()
        flag.set()
Example #20
    def test_synchronous(self):
        flag = Event()
        assert not flag.is_set()

        def closure():
            sleep(1)
            assert not flag.is_set()
            flag.set()

        xm.add_step_closure(closure)
        xm.mark_step()

        # should not get to this part before closure is finished running
        assert flag.is_set()
Example #21
    def begin_epoch(self, epoch):
        """Called at the beginning of each epoch."""
        logger.info("begin training epoch {}".format(epoch))

        if self.quantizer is not None:
            self.quantizer.begin_epoch(epoch)

        # task specific setup per epoch
        self.task.begin_epoch(epoch, self.get_model())

        if self.tpu:
            import torch_xla.core.xla_model as xm

            xm.rendezvous('begin_epoch')  # wait for all workers
            xm.mark_step()
Example #22
    def next(self):
        if xp.get_tracer_marked_step():
            xp.set_tracer_marked_step(False)
            self._batches_yielded += 1
        else:
            if self._mark_step_batch_count <= self._batches_yielded:
                self._batches_yielded = 0
                xm.mark_step()
            else:
                self._batches_yielded += 1

        item = self._loader.next_item(self._device)
        if item is None:
            raise StopIteration
        return item
Example #23
 def test_broadcast(self):
   device = xm.xla_device()
   tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
   pg_xla = get_process_group_xla(rank=0, size=8)
   opts = dist.BroadcastOptions()
   opts.rootRank = 0
   opts.rootTensor = 0
   # XLA has no native broadcast op; broadcast is implemented via all_reduce.
   all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
   with xm_cc_op_intercepted('all_reduce'):
     pg_xla.broadcast([tensor], opts)
   hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
   hlo_matches(hlo, all_reduce_pattern)
   # purge all computations attached to the device.
   xm.mark_step()
Example #24
    def test_synchronous_exception(self):
        flag = Event()
        assert not flag.is_set()

        try:

            def closure():
                flag.set()
                raise RuntimeError("Simulating exception in closure")

            xm.add_step_closure(closure)
            xm.mark_step()

            assert False  # Should not reach here
        except RuntimeError as e:
            assert flag.is_set(), "Should have caught exception from closure"
Example #25
def loop_with_amp(model, input, positions, target, causal_mask, optimizer,
                  xla_enabled, autocast, scaler):
    with autocast():
        loss = model(input, positions, target, batch_mask=causal_mask)
    if xla_enabled:
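        # With XLA, gradients are all-reduced manually between backward() and scaler.step().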
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    return loss
Example #26
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         with autocast():
             output = model(data)
             loss = loss_fn(output, target)
         scaler.scale(loss).backward()
         gradients = xm._fetch_gradients(optimizer)
         xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
         scaler.step(optimizer)
         scaler.update()
         xm.mark_step()
         tracker.add(flags.batch_size)
         if step % flags.log_steps == 0:
             xm.add_step_closure(_train_update,
                                 args=(device, step, loss, tracker, writer))
Example #27
def broadcast_xla_master_model_param(model):
    logger.info(
        "Broadcasting XLA model parameters and buffers from master process ..."
    )

    parameters_and_buffers = []
    for p in chain(model.parameters(), model.buffers()):
        # Set all params in non-master devices to zero so that all_reduce is equivalent
        # to broadcasting parameters from master to other devices.
        if not is_main():
            zero = torch.tensor(0, dtype=p.data.dtype, device=p.data.device)
            p.data.mul_(zero)
        parameters_and_buffers.append(p.data)
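    # Wait for pending device work, then sum parameters across replicas; only the master's values are nonzero.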
    xm.wait_device_ops()
    xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
    xm.mark_step()
    xm.rendezvous("mmf.trainers.core.device.broadcast_xla_master_model_param")
    logger.info("Done!")
Example #28
def loop_with_amp(model, input_ids, attention_mask, labels, optim, xla_enabled, autocast, scaler):
    with autocast():
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
    
    if xla_enabled:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optim)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optim)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

    return loss, optim
Example #29
def xla_run_inference(rank, learner_args, add_args, inference_args, ctrl_args):
    sync_valid = True
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args,
                                     ctrl_args)
    pred_args, master_cbs = setup_inference_args(rank, inference_args)

    if rank == 0 and len(master_cbs) > 0:
        learner.add_cbs(master_cbs)

    # learner.synced_cancel.before_fit()

    if rank == 0:
        learner.sync_recorder.orig_logger = learner.logger

    results = learner.inner_get_preds(**pred_args)
    xm.rendezvous('xla_run_inference')

    save_pred_results(rank, results)
    xm.mark_step()
Example #30
  def test_allreduce_with_mesh(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()

    set_world_size(6)
    ranks = [2, 3]
    world_rank = 3
    set_world_rank(world_rank)
    with new_group_barrier_disabled():
      new_pg = dist.new_group(ranks=ranks)
    opts = dist.AllreduceOptions()
    opts.reduceOp = dist.ReduceOp.SUM
    all_reduce_pattern = (r'%all\-reduce\.\d+ = .+ all\-reduce\(.+\), .*'
                          r'replica_groups=\{\{0,1\},\{2,3\},\{4,5\}\}')
    with xm_cc_op_intercepted('all_reduce'):
      new_pg.allreduce([tensor], opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)
    # purge all computations attached to the device.
    xm.mark_step()