def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler.

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size, as the user is expected to control those. This is a fine
    # assumption for now: not supporting single-item-based IterableDataset
    # avoids unnecessary complexity and config parameters in the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type
    return loader, other_args.get("sampler", None)
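The helper `_add_extra_args_for_dataloader` is not shown in this snippet. A minimal sketch of what it plausibly does, assuming MMF's convention of defaulting `shuffle` by split, attaching a `DistributedSampler` under distributed training, and filling in the global batch size (the exact helper names are assumptions):

import torch

def _add_extra_args_for_dataloader(dataset_instance, other_args):
    dataset_type = dataset_instance.dataset_type

    # Shuffle everything except the test split unless explicitly configured.
    if other_args["shuffle"] is None:
        other_args["shuffle"] = dataset_type != "test"

    # In distributed mode, shuffle is delegated to the DistributedSampler;
    # DataLoader forbids passing both `shuffle` and `sampler`.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=other_args["shuffle"]
        )
        other_args.pop("shuffle")

    if other_args["batch_size"] is None:
        other_args["batch_size"] = get_batch_size()  # assumed MMF helper

    return other_args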
def main(args):
    torch.manual_seed(args.seed)
    device = xm.xla_device()

    loader_kwargs = {
        'num_workers': args.num_workers,
        'batch_size': args.batch_size,
    }

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST(
        'data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST('data', train=False, transform=transform)

    if xm.xrt_world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, sampler=train_sampler, **loader_kwargs)
    train_loader = pl.MpDeviceLoader(train_loader, device)
    test_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)
    test_loader = pl.MpDeviceLoader(test_loader, device)

    model = Net().to(device)

    # Scale learning rate to world size
    lr = args.learning_rate * xm.xrt_world_size()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    for epoch in range(1, args.num_epochs + 1):
        train_one_epoch(args, model, device, train_loader, optimizer, epoch)
        validate(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
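Functions like `main` above run once per XLA process; something has to fork those processes. A minimal launch sketch, assuming this lives in a standalone script (the flag defaults below are illustrative, not from the original):

import argparse
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, args):
    # xmp.spawn passes the process ordinal as the first argument.
    main(args)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.7)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--save_model', action='store_true')
    args = parser.parse_args()
    # Spawns one process per available XLA device and runs _mp_fn in each.
    xmp.spawn(_mp_fn, args=(args,))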
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler.

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    if version.parse(torch.__version__) >= version.parse("1.8"):
        # only use persistent workers in PyTorch 1.8 or higher
        # (PyTorch 1.7 also has this option but doesn't support it correctly due to
        # https://github.com/pytorch/pytorch/issues/48370)
        other_args["persistent_workers"] = datamodule_config.get(
            "persistent_workers", training_config.get("persistent_workers", True)
        )
        if other_args["persistent_workers"] and other_args["num_workers"] == 0:
            logger.warning(
                "persistent_workers cannot be used together with num_workers == 0; "
                "setting persistent_workers to False"
            )
            other_args["persistent_workers"] = False

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size, as the user is expected to control those. This is a fine
    # assumption for now: not supporting single-item-based IterableDataset
    # avoids unnecessary complexity and config parameters in the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    # Set drop_last=True when using XLA to have constant batch size.
    # In this case we also need to set drop_last=True in DistributedSampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type
    return loader, other_args.get("sampler", None)
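The warning above mirrors a hard constraint in PyTorch itself: `DataLoader` rejects `persistent_workers=True` when `num_workers == 0`, so the config fixup avoids a crash rather than just a warning. A standalone repro sketch:

import torch

ds = torch.utils.data.TensorDataset(torch.arange(8))
try:
    torch.utils.data.DataLoader(ds, num_workers=0, persistent_workers=True)
except ValueError as e:
    print(e)  # persistent_workers option needs num_workers > 0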
def train_bert(dataset_path, xla_enabled, amp_enabled):
    max_seq_length = 128
    batch_size = 16
    num_epochs = 1
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # model = BERT()
    model = BERTdownsized()

    dat = pd.read_csv(dataset_path)
    print(dat.head())
    X = dat['review']
    y = dat['sentiment']
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.10, random_state=42)
    X_train = X_train.values.tolist()
    X_test = X_test.values.tolist()
    y_train = pd.get_dummies(y_train).values.tolist()
    y_test = pd.get_dummies(y_test).values.tolist()

    train_lists = [X_train, y_train]
    test_lists = [X_test, y_test]

    training_dataset = text_dataset(
        x_y_list=train_lists, max_seq_length=max_seq_length, tokenizer=tokenizer)
    test_dataset = text_dataset(
        x_y_list=test_lists, max_seq_length=max_seq_length, tokenizer=tokenizer)

    dataloaders_dict = {
        'train': torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
        'val': torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
    }
    dataset_sizes = {'train': len(train_lists[0]), 'val': len(test_lists[0])}

    if xla_enabled:
        device = xm.xla_device()
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    lrlast = 1e-3
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lrlast)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    print('==> Starting Training')

    if amp_enabled:
        autocast, scaler = get_autocast_and_scaler(xla_enabled)

    if xla_enabled:
        import torch_xla.distributed.parallel_loader as pl
        server = xp.start_server(port_number)
        train_device_loader = pl.MpDeviceLoader(dataloaders_dict['train'], device)
        # train_device_loader = dataloaders_dict['train']
    else:
        train_device_loader = dataloaders_dict['train']

    # The `and False` below intentionally disables the NVTX profiling branch.
    if dlprof_enabled and not xla_enabled and False:
        with torch.autograd.profiler.emit_nvtx():
            for epoch in range(num_epochs):
                epoch_time = time.time()
                # tracker = xm.RateTracker()
                print('Epoch {}/{}'.format(epoch, num_epochs - 1))
                print('-' * 10)
                model.train()  # Set model to training mode
                # Iterate over data.
                for step, (inputs, sentiment) in enumerate(train_device_loader):
                    # Placing the tracker here frees it of I/O time.
                    tracker = xm.RateTracker()
                    if not xla_enabled:
                        # This section is not necessary (but doesn't cause any
                        # performance problems) for XLA
                        inputs = inputs.to(device)
                        sentiment = sentiment.to(device)
                    optimizer.zero_grad()
                    if amp_enabled:
                        loss, optimizer = loop_with_amp(
                            model, inputs, sentiment, optimizer, xla_enabled,
                            autocast, scaler)
                    else:
                        loss, optimizer = loop_without_amp(
                            model, inputs, sentiment, optimizer, xla_enabled)
                    tracker.add(inputs.size(0))
                    _train_update(device, step, loss, tracker, epoch, None)
                time_elapsed = time.time() - epoch_time
                print(f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s')
    else:
        for epoch in range(num_epochs):
            epoch_time = time.time()
            # tracker = xm.RateTracker()
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
            model.train()  # Set model to training mode
            # Iterate over data.
            if cpu_mem_usage:
                import resource
                print(f" CPU Usage Before: "
                      f"{resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")
            for step, (inputs, sentiment) in enumerate(train_device_loader):
                if step == 5:
                    training_started.set()
                # Placing the tracker here frees it of I/O time.
                tracker = xm.RateTracker()
                if not xla_enabled:
                    # This section is not necessary (but doesn't cause any
                    # performance problems) for XLA
                    inputs = inputs.to(device)
                    sentiment = sentiment.to(device)
                optimizer.zero_grad()
                if amp_enabled:
                    loss, optimizer = loop_with_amp(
                        model, inputs, sentiment, optimizer, xla_enabled,
                        autocast, scaler)
                else:
                    loss, optimizer = loop_without_amp(
                        model, inputs, sentiment, optimizer, xla_enabled)
                tracker.add(inputs.size(0))
                _train_update(device, step, loss, tracker, epoch, None)
            time_elapsed = time.time() - epoch_time
            print(f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s')

    if xla_enabled and debug_enabled:
        import torch_xla.debug.metrics as met
        print(met.metrics_report())
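`get_autocast_and_scaler`, `loop_with_amp`, and `loop_without_amp` are referenced above but not defined in this snippet. A plausible sketch under two assumptions: torch.cuda.amp on GPU, and torch_xla's amp module plus an explicit gradient all-reduce on XLA (the pattern matches the AMP ImageNet loop later in this section; all helper names here are illustrative):

import torch

def get_autocast_and_scaler(xla_enabled):
    if xla_enabled:
        import torch_xla.amp
        return torch_xla.amp.autocast, torch_xla.amp.GradScaler()
    return torch.cuda.amp.autocast, torch.cuda.amp.GradScaler()

def loop_with_amp(model, inputs, sentiment, optimizer, xla_enabled, autocast, scaler):
    with autocast():
        output = model(inputs)
        loss = torch.nn.functional.cross_entropy(output, sentiment.argmax(dim=1))
    scaler.scale(loss).backward()
    if xla_enabled:
        import torch_xla.core.xla_model as xm
        # Manually all-reduce before unscale/step, since scaler.step calls
        # the plain optimizer.step rather than xm.optimizer_step.
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
    scaler.step(optimizer)
    scaler.update()
    return loss, optimizer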
def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType,
    training_config: mmf_typings.DictConfig,
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sampler.

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring params for dataloader

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size, as the user is expected to control those. This is a fine
    # assumption for now: not supporting single-item-based IterableDataset
    # avoids unnecessary complexity and config parameters in the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
        if is_xla():
            other_args["sampler"] = torch.utils.data.DistributedSampler(
                dataset_instance,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=other_args["shuffle"],
            )
            # Shuffle is mutually exclusive with sampler; the sampler handles it.
            other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        num_workers=num_workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = pl.MpDeviceLoader(loader, device)

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type
    return loader, other_args.get("sampler", None)
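`is_xla()` is imported from MMF's utilities and not shown here. A minimal sketch of the check, assuming the trainer records the target device in its global training config (the exact config key is an assumption):

def is_xla():
    from mmf.utils.configuration import get_global_config
    # True when the trainer was configured to run on an XLA (TPU) device.
    return get_global_config("training").get("device", "cuda") == "xla"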
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            persistent_workers=True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            persistent_workers=True,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')()
    # Wrap the model with FSDP
    # You may wrap all, a subset, or none of the sub-modules with inner FSDPs
    # - to implement ZeRO-2, wrap none of the sub-modules
    # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP)
    # - you may wrap sub-modules at different granularity (e.g. at each resnet
    #   stage or each residual block or each conv layer).
    fsdp_wrap = lambda m: FSDP(
        m.to(device),
        compute_dtype=getattr(torch, FLAGS.compute_dtype),
        fp32_reduce_scatter=FLAGS.fp32_reduce_scatter,
        flatten_parameters=FLAGS.flatten_parameters)
    # Apply gradient checkpointing to sub-modules if specified
    grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else (
        lambda x: x)
    if FLAGS.use_nested_fsdp:
        # Here we apply inner FSDP at the level of child modules for ZeRO-3, which
        # corresponds to different stages in resnet (i.e. Stage 1 to 5).
        for submodule_name, submodule in model.named_children():
            if sum(p.numel() for p in submodule.parameters()) == 0:
                # Skip those submodules without parameters (i.e. no need to shard them)
                continue
            # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
            m_fsdp = fsdp_wrap(grad_ckpt_wrap(getattr(model, submodule_name)))
            setattr(model, submodule_name, m_fsdp)
    # Always wrap the base model with an outer FSDP
    model = fsdp_wrap(model)

    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.WarmupAndExponentialDecayScheduler(
        optimizer,
        num_steps_per_epoch=num_training_steps_per_epoch,
        divide_every_n_epochs=FLAGS.lr_scheduler_divide_every_n_epochs,
        divisor=FLAGS.lr_scheduler_divisor,
        num_warmup_epochs=FLAGS.num_warmup_epochs,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()  # do not reduce gradients on sharded params
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
        run_eval = ((not FLAGS.test_only_at_end and
                     epoch % FLAGS.eval_interval == 0) or
                    epoch == FLAGS.num_epochs)
        if run_eval:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={'Accuracy/test': accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
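The `FSDP` and `checkpoint_module` names above come from torch_xla's XLA-specific FSDP package (distinct from torch.distributed.fsdp). The imports this loop assumes look roughly like the following; `consolidate_sharded_model_checkpoints` is included since the MNIST FSDP variant later in this section uses it:

from torch_xla.distributed.fsdp import (
    XlaFullyShardedDataParallel as FSDP,
    checkpoint_module,
    consolidate_sharded_model_checkpoints,
)

Note the design choice the comments describe: the optimizer calls the plain `optimizer.step()` rather than `xm.optimizer_step(optimizer)`, because FSDP already reduce-scatters gradients for its shards, and an extra all-reduce would double-count.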
def training(rank, world_size, backend, config):
    # Specific xla
    print(xm.get_ordinal(), ": run with config:", config, "- backend=", backend)
    device = xm.xla_device()

    # Data preparation
    dataset = RndDataset(nb_samples=config["nb_samples"])

    # Specific xla
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(config["batch_size"] / xm.xrt_world_size()),
        num_workers=1,
        sampler=train_sampler,
    )

    # Specific xla
    para_loader = pl.MpDeviceLoader(train_loader, device)

    # Model, criterion, optimizer setup
    model = wide_resnet50_2(num_classes=100).to(device)
    criterion = NLLLoss()
    optimizer = SGD(model.parameters(), lr=0.01)

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(batch_idx, data, target):
        # MpDeviceLoader has already moved data/target to the XLA device.
        optimizer.zero_grad()
        output = model(data)
        # NLLLoss expects log-probabilities over the class dimension (dim=1).
        # The original applied softmax over dim=0, which normalizes across the
        # batch and feeds raw probabilities to NLLLoss; log_softmax over dim=1
        # is the correct pairing.
        log_probs = torch.nn.functional.log_softmax(output, dim=1)
        loss_val = criterion(log_probs, target)
        loss_val.backward()
        xm.optimizer_step(optimizer)

        if batch_idx % log_interval == 0:
            print(
                "Process {}/{} Train Epoch: {} [{}/{}]\tLoss: {}".format(
                    xm.get_ordinal(),
                    xm.xrt_world_size(),
                    epoch,
                    batch_idx * len(data),
                    len(train_sampler),
                    loss_val.item(),
                )
            )
        return loss_val

    # Running _train_step for n_epochs
    n_epochs = 1
    for epoch in range(n_epochs):
        for batch_idx, (data, target) in enumerate(para_loader):
            _train_step(batch_idx, data, target)
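A standalone check of the loss fix above: NLLLoss on log_softmax outputs is numerically the same as CrossEntropyLoss on raw logits, which is why either pairing works but softmax-plus-NLLLoss does not.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 100)           # (batch, classes)
target = torch.randint(0, 100, (4,))
a = F.nll_loss(F.log_softmax(logits, dim=1), target)
b = F.cross_entropy(logits, target)
assert torch.allclose(a, b)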
def train_mnist(flags, training_started=None, dynamic_graph=False, fetch_often=False):
    torch.manual_seed(1)
    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=600000 // flags.batch_size // xm.xrt_world_size(),
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=100000 // flags.batch_size // xm.xrt_world_size(),
        )
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True,
            )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers,
        )

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    server = xp.start_server(flags.profiler_port)

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                index = max(-step, -flags.batch_size + 1)  # non-empty
                data, target = data[:-index, :, :, :], target[:-index]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()
            with xp.StepTrace("train_mnist", step_num=step):
                with xp.Trace("build_graph"):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss_i = loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(
                        _train_update, args=(device, step, loss, tracker, writer)
                    )

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace("test_mnist"):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print("Epoch {} train end {}".format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_device_loader)
        xm.master_print(
            "Epoch {} test end {}, Accuracy={:.2f}".format(
                epoch, test_utils.now(), accuracy
            )
        )
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={"Accuracy/test": accuracy},
            write_xla_metrics=True,
        )
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy
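The `xp.start_server(flags.profiler_port)` call above only opens a profiling port inside the training process; traces are captured on demand from a separate client. A minimal capture sketch (the port, logdir, and duration are illustrative):

import torch_xla.debug.profiler as xp

# Connect to the running training process and record a 10-second trace
# into a TensorBoard-readable log directory.
xp.trace('localhost:9012', logdir='/tmp/xla_profile', duration_ms=10000)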
def train_imagenet():
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size(),
        )
        if FLAGS.validate:
            test_loader = xu.SampleGenerator(
                data=(
                    torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                    torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
                ),
                sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
            )
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        if FLAGS.validate:
            test_dataset = torchvision.datasets.ImageFolder(
                os.path.join(FLAGS.datadir, "val"),
                # Matches Torchvision's eval transforms except Torchvision uses size
                # 256 resize for all models both here and in the train loader. Their
                # version crashes during training on 299x299 images, e.g. inception.
                transforms.Compose([
                    transforms.Resize(resize_dim),
                    transforms.CenterCrop(img_dim),
                    transforms.ToTensor(),
                    normalize,
                ]),
            )

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            if FLAGS.validate:
                test_sampler = torch.utils.data.distributed.DistributedSampler(
                    test_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        if FLAGS.validate:
            test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=FLAGS.test_set_batch_size,
                sampler=test_sampler,
                drop_last=FLAGS.drop_last,
                shuffle=False,
                num_workers=FLAGS.num_workers,
            )

    device = xm.xla_device()
    model = get_model_property("model_fn")().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
        scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, "lr_scheduler_divide_every_n_epochs", None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer,
    )
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()

    def train_loop_fn(loader, epoch):
        if FLAGS.fine_grained_metrics:
            epoch_start_time = time.time()
            step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
        else:
            tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if FLAGS.fine_grained_metrics:
                step_start_time = time.time()
            optimizer.zero_grad()
            if FLAGS.fine_grained_metrics:
                fwd_start_time = time.time()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)
            if FLAGS.fine_grained_metrics:
                fwd_end_time = time.time()
                fwd_latency = fwd_end_time - fwd_start_time
                bwd_start_time = time.time()
            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            if lr_scheduler:
                lr_scheduler.step()
            if FLAGS.fine_grained_metrics:
                bwd_end_time = time.time()
                bwd_latency = bwd_end_time - bwd_start_time
                step_latency = bwd_end_time - step_start_time
                step_latency_tracker.append(step_latency)
                bwd_latency_tracker.append(bwd_latency)
                fwd_latency_tracker.append(fwd_latency)
            else:
                tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                if FLAGS.fine_grained_metrics:
                    print('FineGrainedMetrics :: Epoch={} Step={} '
                          'Rate(DataPoints/s)[p50]={:.1f} BatchSize={} '
                          'Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} '
                          'Bwd(s/Batch)[p50]={:.4f}'.format(
                              epoch, step,
                              FLAGS.batch_size / p50(step_latency_tracker),
                              FLAGS.batch_size, p50(step_latency_tracker),
                              p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
                else:
                    # _train_update(device, step, loss, tracker, epoch, writer)
                    xm.add_step_closure(
                        _train_update,
                        args=(device, step, loss, tracker, epoch, writer))
        if FLAGS.fine_grained_metrics:
            epoch_end_time = time.time()
            epoch_latency = epoch_end_time - epoch_start_time
            print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} '
                  'Rate(DataPoints/s)[p50]={:.1f} BatchSize={} '
                  'Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} '
                  'Bwd(s/Batch)[p50]={:.4f}'.format(
                      epoch, epoch_latency,
                      FLAGS.batch_size / p50(step_latency_tracker),
                      FLAGS.batch_size, p50(step_latency_tracker),
                      p50(bwd_latency_tracker), p50(fwd_latency_tracker)))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                test_utils.print_test_update(device, None, epoch, step)
                # xm.add_step_closure(
                #     test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    if FLAGS.validate:
        test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print("Epoch {} train end {}".format(epoch, test_utils.now()))
        if FLAGS.validate:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={"Accuracy/test": accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    if FLAGS.validate:
        xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy if FLAGS.validate else None
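`p50` is called in the metrics prints above but never defined in this snippet. A plausible sketch, assuming it is a simple median helper over the recorded latency lists:

import numpy as np

def p50(values):
    # Median (50th percentile) of the recorded per-batch latencies.
    return float(np.percentile(values, 50))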
def fit(self, train_loader, validation_loader):
    param_optimizer = list(self.model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    # Try using different LRs for the HEAD and the EffNet backbone
    # self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.GPU_LR)
    LR = self.config.TPU_LR
    if global_config.CONTINUE_TRAIN:
        # Continue training proc -> hand-tune LR
        LR = self.config.TPU_LR  # [9e-4, 1e-3]
    self.optimizer = torch.optim.AdamW([
        {'params': self.model.efn.parameters(), 'lr': LR[0]},
        {'params': self.model.fc1.parameters(), 'lr': LR[1]},
        {'params': self.model.bn1.parameters(), 'lr': LR[1]},
        {'params': self.model.dense_out.parameters(), 'lr': LR[1]},
    ])

    ##############################################
    self.scheduler = self.config.SchedulerClass(
        self.optimizer, **self.config.scheduler_params)
    # num_train_steps = int(self.steps * (global_config.GPU_EPOCH))
    # self.scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    #     self.optimizer,
    #     num_warmup_steps=int(num_train_steps * 0.05),  # WARMUP_PROPORTION = 0.1 as default
    #     num_training_steps=num_train_steps,
    #     num_cycles=0.5
    # )
    ##############################################

    # DataLoaders should be wrapped only once (outside the epoch loop).
    train_device_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())
    # `validation_loader == 1` is used as a sentinel meaning "skip validation".
    if validation_loader != 1:
        val_device_loader = pl.MpDeviceLoader(validation_loader, xm.xla_device())
    ##############################################

    for e in range(self.config.TPU_EPOCH):
        # Training
        gc.collect()
        t = time.time()
        xm.master_print("---" * 31)

        summary_loss, final_scores = self.train_one_epoch(train_device_loader)
        effNet_lr = np.format_float_scientific(
            self.optimizer.param_groups[0]['lr'], unique=False, precision=1)
        head_lr = np.format_float_scientific(
            self.optimizer.param_groups[1]['lr'], unique=False, precision=1)
        self.log(f":::[Train RESULT]| Epoch: {str(self.epoch).rjust(2, ' ')} | "
                 f"Loss: {summary_loss.avg:.4f} | AUC: {final_scores.avg:.4f} | "
                 f"LR: {effNet_lr}/{head_lr} | "
                 f"Time: {int((time.time() - t)//60)}m")
        self.save(f'{self.base_dir}/last_ckpt.pt')

        # Validation
        gc.collect()
        t = time.time()
        if validation_loader != 1:  # skip validation when the sentinel is passed
            summary_loss, final_scores = self.validation(val_device_loader)
            self.log(f":::[Valid RESULT]| Epoch: {str(self.epoch).rjust(2, ' ')} | "
                     f"Loss: {summary_loss.avg:.4f} | AUC: {final_scores.avg:.4f} | "
                     f"LR: {effNet_lr}/{head_lr} | "
                     f"Time: {int((time.time() - t)//60)}m")
            if summary_loss.avg < self.best_summary_loss:
                self.best_summary_loss = summary_loss.avg
                self.model.eval()
                self.save(f'{self.base_dir}/'
                          f'{global_config.SAVED_NAME}_{str(self.epoch).zfill(3)}ep.pt')
                # keep only the best 3 checkpoints
                # for path in sorted(glob(f'{self.base_dir}/{global_config.SAVED_NAME}_*ep.pt'))[:-3]:
                #     os.remove(path)

        if self.config.validation_scheduler:
            try:
                # Metric-driven schedulers (e.g. ReduceLROnPlateau) consume the loss.
                self.scheduler.step(metrics=summary_loss.avg)
            except TypeError:
                # Schedulers that don't accept a `metrics` argument.
                self.scheduler.step()

        self.epoch += 1
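The try/except around `scheduler.step` above exists because the scheduler class is injected via config and may or may not be metric-driven. A hypothetical config illustrating the assumed convention (all names and values here are illustrative, not from the original):

import torch

class TrainConfig:
    TPU_EPOCH = 10
    TPU_LR = [9e-4, 1e-3]            # [backbone LR, head LR]
    validation_scheduler = True      # step the scheduler on validation loss
    SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
    scheduler_params = dict(mode='min', factor=0.5, patience=1)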
def train_bert(dataset_path, xla_enabled, amp_enabled):
    max_seq_length = 256
    batch_size = 32
    num_epochs = 25
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BERT()

    dat = pd.read_csv(dataset_path)
    print(dat.head())
    X = dat['review']
    y = dat['sentiment']
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.10, random_state=42)
    X_train = X_train.values.tolist()
    X_test = X_test.values.tolist()
    y_train = pd.get_dummies(y_train).values.tolist()
    y_test = pd.get_dummies(y_test).values.tolist()

    train_lists = [X_train, y_train]
    test_lists = [X_test, y_test]

    training_dataset = text_dataset(
        x_y_list=train_lists, max_seq_length=max_seq_length, tokenizer=tokenizer)
    test_dataset = text_dataset(
        x_y_list=test_lists, max_seq_length=max_seq_length, tokenizer=tokenizer)

    dataloaders_dict = {
        'train': torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
        'val': torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, num_workers=0),
    }
    dataset_sizes = {'train': len(train_lists[0]), 'val': len(test_lists[0])}

    if xla_enabled:
        device = xm.xla_device()
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    lrlast = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=lrlast)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    model = model.to(device)

    print('==> Starting Training')

    if amp_enabled:
        autocast, scaler = get_autocast_and_scaler(xla_enabled)

    train_device_loader = pl.MpDeviceLoader(dataloaders_dict['train'], device)
    for epoch in range(num_epochs):
        epoch_time = time.time()
        tracker = xm.RateTracker()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        model.train()  # Set model to training mode
        # Iterate over data.
        for step, (inputs, sentiment) in enumerate(train_device_loader):
            # import pdb; pdb.set_trace()
            # sentiment = torch.max(sentiment.float(), 1)[1]
            # inputs = inputs.to(device)
            # sentiment = sentiment.to(device)
            optimizer.zero_grad()
            if amp_enabled:
                loss, optimizer = loop_with_amp(
                    model, inputs, sentiment, optimizer, xla_enabled,
                    autocast, scaler)
            else:
                loss, optimizer = loop_without_amp(
                    model, inputs, sentiment, optimizer, xla_enabled)
            tracker.add(inputs.size(0))
            _train_update(device, step, loss, tracker, epoch, None)
        time_elapsed = time.time() - epoch_time
        print(f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s')
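`_train_update` is shared by several of these loops but never defined in this section. A sketch modeled on torch_xla's test utilities (assumed, not verbatim from the original codebase):

import torch_xla.test.test_utils as test_utils

def _train_update(device, step, loss, tracker, epoch, writer):
    # loss.item() forces a device-to-host transfer, which is why several
    # loops above defer this call via xm.add_step_closure.
    test_utils.print_training_update(
        device, step, loss.item(), tracker.rate(), tracker.global_rate(),
        epoch, summary_writer=writer)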
def train_eval_loop(
    model,
    loss,
    optimizer,
    scheduler,
    train_loader,
    test_loader,
    device,
    epochs,
    verbose,
    save,
    save_freq=None,
    save_path=None,
    epoch_offset=0,
    **kwargs,
):
    print_fn = print
    if device.type == "xla":
        import torch_xla.distributed.parallel_loader as pl
        import torch_xla.core.xla_model as xm

        print_fn = xm.master_print
        train_loader = pl.MpDeviceLoader(train_loader, device)
        test_loader = pl.MpDeviceLoader(test_loader, device)

    test_loss, accuracy1, accuracy5 = eval(
        model, loss, test_loader, device, verbose, 0)
    metric_dict = {
        "train_loss": 0,
        "test_loss": test_loss,
        "accuracy1": accuracy1,
        "accuracy5": accuracy5,
    }
    if save:
        checkpoint(
            model,
            optimizer,
            scheduler,
            0,
            0,
            save_path,
            verbose,
            metric_dict,
            tpu=(device.type == "xla"),
        )

    for epoch in tqdm(range(epoch_offset, epoch_offset + epochs)):
        train_loss = train(
            model,
            loss,
            optimizer,
            scheduler,
            train_loader,
            device,
            epoch,
            verbose,
            save,
            save_freq=save_freq,
            save_path=save_path,
            **kwargs,
        )
        test_loss, accuracy1, accuracy5 = eval(
            model, loss, test_loader, device, verbose, epoch + 1
        )
        metric_dict = {
            "train_loss": train_loss,
            "test_loss": test_loss,
            "accuracy1": accuracy1,
            "accuracy5": accuracy5,
        }
        curr_step = (epoch + 1) * kwargs.get("num_batches")
        if save:
            checkpoint(
                model,
                optimizer,
                scheduler,
                epoch,
                curr_step,
                save_path,
                verbose,
                metric_dict,
                tpu=(device.type == "xla"),
            )
        scheduler.step()

    if epochs > 0:
        print_fn(
            f"Final performance: "
            f"\tTrain Loss: {train_loss:.4f}"
            f"\tTest Loss: {test_loss:.4f}"
            f"\tAccuracy: {accuracy1:.2f}%"
        )
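`checkpoint` is called with a `tpu=` flag above but is not defined in this snippet. A plausible sketch, assuming the only XLA-specific difference is saving through `xm.save` (which moves tensors off the device and, by default, writes only from the master ordinal):

import os
import torch

def checkpoint(model, optimizer, scheduler, epoch, step, save_path, verbose,
               metric_dict, tpu=False):
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
        "step": step,
        "metrics": metric_dict,
    }
    path = os.path.join(save_path, f"ckpt_{epoch:04d}.pt")
    if tpu:
        import torch_xla.core.xla_model as xm
        xm.save(state, path)  # master-only by default
    else:
        torch.save(state, path)
    if verbose:
        print(f"saved checkpoint to {path}")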
def _main_xla(index, args):
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

    alphabet = alphabet_factory()
    train_dataset, test_dataset = split_dataset(args, alphabet)
    collate_fn = collate_factory(model_length_function)

    if xm.xrt_world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        drop_last=True)

    # Scale learning rate to world size
    lr = args.learning_rate * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = build_deepspeech(in_features=in_features, num_classes=len(alphabet))
    model = model.to(device)
    optimizer = get_optimizer(args, model.parameters())
    criterion = nn.CTCLoss(blank=alphabet.mapping[alphabet.char_blank])
    decoder = GreedyDecoder()

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    class XLAProxyOptimizer:
        """XLA proxy optimizer for compatibility with torch.Optimizer."""

        def __init__(self, optimizer):
            self.optimizer = optimizer

        def zero_grad(self):
            self.optimizer.zero_grad()

        def step(self):
            xm.optimizer_step(self.optimizer)

    optimizer = XLAProxyOptimizer(optimizer)
    train_eval_fn(args.num_epochs, train_device_loader, test_device_loader,
                  optimizer, model, criterion, device, decoder, alphabet,
                  args.checkpoint)
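The proxy's `step()` delegates to `xm.optimizer_step` because on XLA the step must both synchronize gradients across replicas and apply the update. A rough sketch of what that call is equivalent to, under the default settings:

import torch_xla.core.xla_model as xm

def manual_optimizer_step(optimizer):
    # All-reduce (mean) gradients across all participating replicas...
    xm.reduce_gradients(optimizer)
    # ...then apply the local optimizer update.
    optimizer.step()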
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_device_loader, epoch)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={'Accuracy/test': accuracy},
            write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
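All three ImageNet variants above call `get_model_property`, which is never defined in this section. A sketch of the convention it appears to follow, assuming the selected model name lives in `FLAGS.model` (the keys and the inception special case are assumptions):

import torchvision

def get_model_property(key):
    default_model_property = {
        'img_dim': 224,
        'model_fn': getattr(torchvision.models, FLAGS.model),
    }
    model_properties = {
        # inception_v3 needs 299x299 inputs; aux logits are disabled so the
        # training loop sees a single output tensor.
        'inception_v3': {
            'img_dim': 299,
            'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False),
        },
    }
    return model_properties.get(FLAGS.model, default_model_property)[key]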
def run_fold(fold):
    create_dirs()
    print_fn = print if not config.USE_TPU else xm.master_print
    print_fn(f"___________________________________________________")
    print_fn(f"Training Model: {config.NET}")
    print_fn(f"Training Fold: {fold}")
    print_fn(f"Image Dimensions: {config.H}x{config.W}")
    print_fn(f"Mixed Precision Training: {config.MIXED_PRECISION_TRAIN}")
    print_fn(f"Training Batch Size: {config.TRAIN_BATCH_SIZE}")
    print_fn(f"Validation Batch Size: {config.VALID_BATCH_SIZE}")
    print_fn(f"Accumulate Iteration: {config.ACCUMULATE_ITERATION}")

    global net
    train_loader, valid_loader = get_loaders(fold)
    device = get_device(n=fold + 1)
    net = net.to(device)
    scaler = (torch.cuda.amp.GradScaler()
              if not config.USE_TPU and config.MIXED_PRECISION_TRAIN else None)
    loss_tr = get_train_criterion(device=device)
    loss_fn = get_valid_criterion(device=device)
    optimizer, scheduler = get_optimizer_and_scheduler(
        net=net, dataloader=train_loader)
    gc.collect()

    for epoch in range(config.MAX_EPOCHS):
        epoch_start = time.time()
        if config.DO_FREEZE_BATCH_NORM and epoch < config.FREEZE_BN_EPOCHS:
            freeze_batchnorm_stats(net)

        train_mp_device_loader = (
            pl.MpDeviceLoader(train_loader, device, fixed_batch_size=True)
            if config.USE_TPU else train_loader)
        train_one_epoch(fold, epoch, net, loss_tr, optimizer,
                        train_mp_device_loader, device, scaler=scaler,
                        scheduler=scheduler,
                        schd_batch_update=config.SCHEDULER_BATCH_STEP)
        del train_mp_device_loader
        gc.collect()

        valid_mp_device_loader = (
            pl.MpDeviceLoader(valid_loader, device, fixed_batch_size=True)
            if config.USE_TPU else valid_loader)
        valid_one_epoch(fold, epoch, net, loss_fn, valid_mp_device_loader,
                        device, scheduler=None, schd_loss_update=False)
        del valid_mp_device_loader
        gc.collect()

        print_fn(
            f'[{fold}/{config.FOLDS - 1}][{epoch:>2d}/{config.MAX_EPOCHS - 1:>2d}] '
            f'Time Taken for Epoch {epoch}: {time.time() - epoch_start} seconds |'
        )

        if config.USE_TPU:
            xm.save(
                net.state_dict(),
                os.path.join(
                    config.WEIGHTS_PATH,
                    f'{config.NET}/{config.NET}_fold_{fold}_{epoch}.bin'))
        else:
            torch.save(
                net.state_dict(),
                os.path.join(
                    config.WEIGHTS_PATH,
                    f'{config.NET}/{config.NET}_fold_{fold}_{epoch}.bin'))
        # torch.save(model.cnn_model.state_dict(),
        #            '{}/cnn_model_fold_{}_{}'.format(CFG['model_path'], fold, CFG['tag']))

    del net, optimizer, train_loader, valid_loader, scheduler
    torch.cuda.empty_cache()
    print_fn(f"___________________________________________________")
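`get_device` is not shown in this snippet; a hypothetical helper matching the call site, assuming `n` selects a TPU core ordinal when folds are trained on separate cores:

import torch

def get_device(n=1):
    if config.USE_TPU:
        import torch_xla.core.xla_model as xm
        return xm.xla_device(n)
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')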
def train_mnist(flags, state_dict):
    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=60000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=10000 // flags.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST()
    model.load_state_dict(state_dict)
    model = model.to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(
                    _train_update,
                    args=(device, step, loss, tracker, writer),
                    run_async=flags.async_closures)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct.item() / total_samples
        # accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={'Accuracy/test': accuracy},
            write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
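This variant takes a CPU `state_dict` so that every spawned process starts from identical weights. A launch sketch under that assumption (the `_mp_fn` wrapper and flag names are illustrative):

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, flags, state_dict):
    train_mnist(flags, state_dict)

if __name__ == '__main__':
    # Build the model once on CPU so all ranks share the same initial weights.
    model_cpu = MNIST()
    xmp.spawn(_mp_fn, args=(FLAGS, model_cpu.state_dict()))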
def train_mnist(flags, **kwargs):
    torch.manual_seed(1)
    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=60000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=10000 // flags.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))]))
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST()
    # Wrap the model with FSDP
    fsdp_wrap = lambda m: FSDP(
        m.to(device),
        compute_dtype=getattr(torch, flags.compute_dtype),
        fp32_reduce_scatter=flags.fp32_reduce_scatter,
        flatten_parameters=flags.flatten_parameters)
    # Apply gradient checkpointing to sub-modules if specified
    grad_ckpt_wrap = checkpoint_module if flags.use_gradient_checkpointing else (
        lambda x: x)
    if flags.use_nested_fsdp:
        # Wrap a few sub-modules with inner FSDP (to implement ZeRO-3)
        # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
        model.conv1 = fsdp_wrap(grad_ckpt_wrap(model.conv1))
        model.conv2 = fsdp_wrap(grad_ckpt_wrap(model.conv2))
        model.fc1 = fsdp_wrap(grad_ckpt_wrap(model.fc1))
        model.fc2 = fsdp_wrap(grad_ckpt_wrap(model.fc2))
    # Always wrap the base model with an outer FSDP
    model = fsdp_wrap(model)

    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(model, loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()  # do not reduce gradients on sharded params
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(
                    _train_update,
                    args=(device, step, loss, tracker, writer),
                    run_async=flags.async_closures)

    def test_loop_fn(model, loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        train_loop_fn(model, train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
        accuracy = test_loop_fn(model, test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={'Accuracy/test': accuracy},
            write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    if flags.ckpt_consolidation:
        # Note: to run this test, all the model checkpoints need to be
        # accessible from the master rank. Set --ckpt_prefix to a shared file
        # system (e.g. NFS) when running on a TPU pod.

        # Save the final model checkpoint
        rank = xm.get_ordinal()
        world_size = xm.xrt_world_size()
        ckpt_path = f'{flags.ckpt_prefix}_rank-{rank:08d}-of-{world_size:08d}.pth'
        ckpt = {
            'model': model.state_dict(),
            'shard_metadata': model.get_shard_metadata(),
            'optimizer': optimizer.state_dict(),  # not needed in ckpt consolidation
        }
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        xm.save(ckpt, ckpt_path, master_only=False)
        print(f'checkpoint saved to {ckpt_path}\n', end='')

        # Consolidate the sharded model checkpoints and test its accuracy
        if xm.is_master_ordinal(local=False):
            consolidate_sharded_model_checkpoints(
                ckpt_prefix=flags.ckpt_prefix, ckpt_suffix="_rank-*-of-*.pth")
        xm.rendezvous('ckpt_consolidation')
        model = MNIST().to(device)
        ckpt_consolidated = torch.load(f'{flags.ckpt_prefix}_consolidated.pth')
        model.load_state_dict(ckpt_consolidated['model'])
        accuracy = test_loop_fn(model, test_device_loader)
        xm.master_print(
            f'Checkpoint consolidated, Accuracy={accuracy:.2f} '
            '(note: it can be slightly different from the final training accuracy '
            'due to non-sync BatchNorm2d in the model)')

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
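As an offline alternative to the in-process consolidation above, torch_xla's FSDP package also exposes a command-line entry point for merging sharded checkpoints after training has finished (the paths below are illustrative):

# python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
#     --ckpt_prefix /shared/mnist_fsdp_ckpt \
#     --ckpt_suffix "_rank-*-of-*.pth"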