Example #1
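Example #1 wraps each stage of a model's forward pass in an xp.Trace context manager, so the convolution, batch-norm, and dense stages appear as separately named spans in a captured profile. The snippet assumes the usual torch_xla setup around it; these imports are not part of the original example:

    import torch
    import torch.nn.functional as F
    import torch_xla.debug.profiler as xp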
    def forward(self, x):
        with xp.Trace('conv1'):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = self.bn1(x)
        with xp.Trace('conv2'):
            x = F.relu(F.max_pool2d(self.conv2(x), 2))
            x = self.bn2(x)
        with xp.Trace('dense'):
            x = torch.flatten(x, 1)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
        return F.log_softmax(x, dim=1)
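The spans above only show up once a profile is actually captured. One way to do that, sketched here as an assumption about the surrounding setup (the port and log directory are placeholders), is to start the profiler server in the training process and request a trace from a separate process while training runs:

    import torch_xla.debug.profiler as xp

    # In the training process: expose the profiler service (keep the handle alive).
    server = xp.start_server(9012)

    # From another process: capture ~10 s of activity into a TensorBoard logdir.
    xp.trace('localhost:9012', '/tmp/profile_logdir', duration_ms=10000)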
Example #2
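Example #2 instruments a training loop: xp.StepTrace marks each training step (and reports per-step wall time), while the nested xp.Trace('build_graph') separates graph construction from the optimizer step. The dynamic_graph, training_started, and fetch_often branches are test-only knobs from the surrounding harness. Assumed imports, not shown in the original:

    import torch_xla.core.xla_model as xm
    import torch_xla.debug.profiler as xp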
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                # Clamp to -1 so even step 0 keeps at least one sample;
                # max() alone yields index == 0 there, i.e. an empty slice.
                index = min(max(-step, -flags.batch_size + 1), -1)  # non-empty
                data, target = data[:-index, :, :, :], target[:-index]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()

            with xp.StepTrace('train_mnist', step_num=step):
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss_i = loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              writer))
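The helper _train_update is not shown in the original; a minimal sketch of what it might look like, given the arguments passed above. Reading loss.item() is cheap here because step closures run after the step's graph has executed:

    def _train_update(device, step, loss, tracker, writer):
        # Runs inside a step closure, so this device-to-host read does not
        # force an extra mid-step synchronization.
        print(f'[{device}] step {step}: loss={loss.item():.4f} '
              f'rate={tracker.rate():.2f} samples/s')
        if writer is not None:  # e.g. a torch.utils.tensorboard SummaryWriter
            writer.add_scalar('loss', loss.item(), step)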
Example #3
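Example #3 folds the same primitives into a profiler hook in the style of Lightning's XLAProfiler: the first recorded action lazily starts the profiler server, step-level actions get an xp.StepTrace, every other action a plain xp.Trace, and the context manager is entered by hand and parked in _recording_map until the matching stop call.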
    def start(self, action_name: str) -> None:
        if action_name in self.RECORD_FUNCTIONS:
            if not self._start_trace:
                self.server = xp.start_server(self.port)
                self._start_trace = True

            if action_name in self.STEP_FUNCTIONS:
                step = self._get_step_num(action_name)
                recording = xp.StepTrace(action_name, step_num=step)
            else:
                recording = xp.Trace(action_name)
            recording.__enter__()
            self._recording_map[action_name] = recording
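Because start() calls __enter__() manually, a matching stop() has to exit the stored context manager. A minimal sketch; this method is an assumption, not part of the original snippet:

    def stop(self, action_name: str) -> None:
        if action_name in self._recording_map:
            # Close the span opened in start().
            self._recording_map[action_name].__exit__(None, None, None)
            del self._recording_map[action_name]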