def start(self, action_name: str) -> None:
    """Open a profiler trace for *action_name* and remember it for `stop`.

    Only actions listed in ``self.RECORD_FUNCTIONS`` are traced. The first
    traced action lazily starts the profiler server on ``self.port``.
    Actions in ``self.STEP_FUNCTIONS`` get a ``StepTrace`` (tagged with a
    step number); everything else gets a plain ``Trace``. The entered
    context manager is stashed in ``self._recording_map`` so a later call
    can ``__exit__`` it.
    """
    if action_name not in self.RECORD_FUNCTIONS:
        return

    # Lazily bring up the profiler server exactly once.
    if not self._start_trace:
        self.server = xp.start_server(self.port)
        self._start_trace = True

    if action_name in self.STEP_FUNCTIONS:
        step_num = self._get_step_num(action_name)
        recording = xp.StepTrace(action_name, step_num=step_num)
    else:
        recording = xp.Trace(action_name)

    # Enter the trace context manually; the matching __exit__ happens when
    # the recording is popped from _recording_map elsewhere.
    recording.__enter__()
    self._recording_map[action_name] = recording
def train_bert(dataset_path, xla_enabled, amp_enabled):
    """Fine-tune a downsized BERT sentiment classifier on a review CSV.

    Args:
        dataset_path: Path to a CSV with 'review' and 'sentiment' columns.
        xla_enabled: If True, train on an XLA device via MpDeviceLoader and
            start the client-side profiler server.
        amp_enabled: If True, run the step through the AMP autocast/scaler
            helpers returned by ``get_autocast_and_scaler``.

    Relies on module-level helpers/globals: ``BertTokenizer``,
    ``BERTdownsized``, ``text_dataset``, ``get_autocast_and_scaler``,
    ``loop_with_amp``, ``loop_without_amp``, ``_train_update``,
    ``port_number``, ``dlprof_enabled``, ``cpu_mem_usage``,
    ``training_started``, ``debug_enabled``.
    """
    max_seq_length = 128
    batch_size = 16
    num_epochs = 1

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BERTdownsized()

    dat = pd.read_csv(dataset_path)
    # BUG FIX: `dat.head` is a bound method; it must be called to print the
    # preview rows instead of printing the method object itself.
    print(dat.head())
    X = dat['review']
    y = dat['sentiment']

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.10, random_state=42)
    X_train = X_train.values.tolist()
    X_test = X_test.values.tolist()
    # One-hot encode the sentiment labels.
    y_train = pd.get_dummies(y_train).values.tolist()
    y_test = pd.get_dummies(y_test).values.tolist()

    train_lists = [X_train, y_train]
    test_lists = [X_test, y_test]
    training_dataset = text_dataset(x_y_list=train_lists,
                                    max_seq_length=max_seq_length,
                                    tokenizer=tokenizer)
    test_dataset = text_dataset(x_y_list=test_lists,
                                max_seq_length=max_seq_length,
                                tokenizer=tokenizer)

    dataloaders_dict = {
        'train': torch.utils.data.DataLoader(training_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=0),
        'val': torch.utils.data.DataLoader(test_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0),
    }
    dataset_sizes = {'train': len(train_lists[0]), 'val': len(test_lists[0])}

    if xla_enabled:
        device = xm.xla_device()
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    lrlast = 1e-3
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lrlast)

    print('==> Starting Training')

    if amp_enabled:
        autocast, scaler = get_autocast_and_scaler(xla_enabled)

    if xla_enabled:
        import torch_xla.distributed.parallel_loader as pl
        # Client-side profiler server (port_number is a module-level global).
        server = xp.start_server(port_number)
        train_device_loader = pl.MpDeviceLoader(dataloaders_dict['train'], device)
    else:
        train_device_loader = dataloaders_dict['train']

    def run_train_step(epoch, step, inputs, sentiment, opt):
        """One optimizer step; returns the (possibly rebound) optimizer."""
        # Create the tracker after data loading so I/O time is excluded
        # from the throughput rate.
        tracker = xm.RateTracker()
        if not xla_enabled:
            # Not necessary for XLA (MpDeviceLoader already placed the
            # batch on the device), and harmless otherwise.
            inputs = inputs.to(device)
            sentiment = sentiment.to(device)
        opt.zero_grad()
        if amp_enabled:
            loss, opt = loop_with_amp(model, inputs, sentiment, opt,
                                      xla_enabled, autocast, scaler)
        else:
            loss, opt = loop_without_amp(model, inputs, sentiment, opt,
                                         xla_enabled)
        tracker.add(inputs.size(0))
        _train_update(device, step, loss, tracker, epoch, None)
        return opt

    # NOTE(review): the trailing `and False` deliberately disables the
    # DLProf/NVTX branch; left as-is to preserve behavior — remove it to
    # re-enable NVTX range emission for non-XLA profiling.
    if dlprof_enabled and not xla_enabled and False:
        with torch.autograd.profiler.emit_nvtx():
            for epoch in range(num_epochs):
                epoch_time = time.time()
                print('Epoch {}/{}'.format(epoch, num_epochs - 1))
                print('-' * 10)
                model.train()  # Set model to training mode
                for step, (inputs, sentiment) in enumerate(train_device_loader):
                    optimizer = run_train_step(epoch, step, inputs, sentiment,
                                               optimizer)
                time_elapsed = time.time() - epoch_time
                print(
                    f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s'
                )
    else:
        for epoch in range(num_epochs):
            epoch_time = time.time()
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
            model.train()  # Set model to training mode
            if cpu_mem_usage:
                import resource
                print(
                    f" CPU Usage Before: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}"
                )
            for step, (inputs, sentiment) in enumerate(train_device_loader):
                if step == 5:
                    # Test/synchronization hook: signal that training is
                    # genuinely under way.
                    training_started.set()
                optimizer = run_train_step(epoch, step, inputs, sentiment,
                                           optimizer)
            time_elapsed = time.time() - epoch_time
            print(
                f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s'
            )

    if xla_enabled and debug_enabled:
        import torch_xla.debug.metrics as met
        print(met.metrics_report())
def train_mnist(flags, **kwargs):
    """Run the MNIST train/eval loop on an XLA device.

    Uses fake data when ``flags.fake_data`` is set, otherwise downloads
    MNIST per-ordinal. Starts the client-side profiler server on
    ``flags.profiler_port`` and wraps training/eval steps in profiler
    traces. Returns the maximum test accuracy observed across epochs.
    """
    torch.manual_seed(1)

    world_size = xm.xrt_world_size()
    if flags.fake_data:
        # Synthetic all-zero batches shaped like MNIST; avoids any download.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=60000 // flags.batch_size // world_size)
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=10000 // flags.batch_size // world_size)
    else:
        # Each ordinal keeps its own copy of the dataset on disk.
        data_dir = os.path.join(flags.datadir, str(xm.get_ordinal()))
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])
        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=mnist_transform)
        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=mnist_transform)
        train_sampler = None
        if world_size > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=world_size,
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=train_sampler is None,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * world_size

    device = xm.xla_device()
    model = MNIST().to(device)

    # Only the master ordinal writes TensorBoard summaries.
    writer = (test_utils.get_summary_writer(flags.logdir)
              if xm.is_master_ordinal() else None)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    # Start up client side profiler server.
    server = xp.start_server(flags.profiler_port)

    # Testing purpose only: set event for synchronization.
    if kwargs.get('worker_started'):
        kwargs.pop('worker_started').set()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            with xp.StepTrace('train_mnist', step_num=step):
                # Graph construction is traced separately from execution.
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(
                        _train_update,
                        args=(device, step, loss, tracker, writer))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace('test_mnist'):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]
        accuracy = 100.0 * correct.item() / total_samples
        # Average the per-core accuracies across the mesh.
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy