def forward(self, x):
    """Run the two conv/pool/bn stages and the dense head; return log-probabilities.

    Each stage sits inside an ``xp.Trace`` scope so it appears as a named
    region in the XLA profiler output.
    """
    with xp.Trace('conv1'):
        pooled = F.max_pool2d(self.conv1(x), 2)
        x = self.bn1(F.relu(pooled))
    with xp.Trace('conv2'):
        pooled = F.max_pool2d(self.conv2(x), 2)
        x = self.bn2(F.relu(pooled))
    with xp.Trace('dense'):
        flat = torch.flatten(x, 1)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
    return F.log_softmax(logits, dim=1)
def train_loop_fn(loader):
    """Train for one epoch over ``loader`` under XLA step tracing.

    Relies on enclosing-scope names: ``model``, ``optimizer``, ``loss_fn``,
    ``device``, ``writer``, ``flags``, ``dynamic_graph``, ``fetch_often``,
    ``training_started``, ``_train_update``.
    """
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        if dynamic_graph:
            # testing purpose only: dynamic batch size and graph.
            # BUGFIX: clamp to at most -1 so the slice keeps >= 1 sample.
            # The previous max(-step, -flags.batch_size + 1) produced
            # index == 0 at step 0 (and whenever batch_size == 1), making
            # data[:-index] an *empty* batch despite the non-empty intent.
            index = min(-1, max(-step, -flags.batch_size + 1))  # non-empty
            data, target = data[:-index, :, :, :], target[:-index]
        if step >= 15 and training_started:
            # testing purpose only: set event for synchronization.
            # Event.set() is idempotent, so re-setting on later steps is harmless.
            training_started.set()
        with xp.StepTrace('train_mnist', step_num=step):
            with xp.Trace('build_graph'):
                optimizer.zero_grad()
                output = model(data)
                loss = loss_fn(output, target)
                loss.backward()
            xm.optimizer_step(optimizer)
            if fetch_often:
                # testing purpose only: fetch XLA tensors to CPU.
                # The value is intentionally unused; .item() forces a
                # device-to-host transfer mid-training.
                loss_i = loss.item()
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, writer))
def start(self, action_name: str) -> None:
    """Open a profiler recording for ``action_name`` if it is a tracked action.

    Lazily boots the profiler server on first use, then enters either a
    ``StepTrace`` (for step-level actions) or a plain ``Trace``, keeping the
    live context manager in ``_recording_map`` so it can be exited later.
    """
    if action_name not in self.RECORD_FUNCTIONS:
        return
    # Start the trace server exactly once, on the first recorded action.
    if not self._start_trace:
        self.server = xp.start_server(self.port)
        self._start_trace = True
    if action_name in self.STEP_FUNCTIONS:
        recording = xp.StepTrace(
            action_name, step_num=self._get_step_num(action_name))
    else:
        recording = xp.Trace(action_name)
    recording.__enter__()
    self._recording_map[action_name] = recording
def train_loop_fn(loader):
    """Train ``model`` for one epoch over ``loader`` under XLA step tracing.

    Uses enclosing-scope names: ``model``, ``optimizer``, ``loss_fn``,
    ``device``, ``writer``, ``flags``, ``_train_update``.
    """
    tracker = xm.RateTracker()
    model.train()
    for step, batch in enumerate(loader):
        data, target = batch
        with xp.StepTrace('train_mnist', step_num=step):
            with xp.Trace('build_graph'):
                optimizer.zero_grad()
                loss = loss_fn(model(data), target)
                loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, writer))