def test(self):
    """Train torchvision's ResNet-18 with DataParallel on all-zero batches.

    The per-batch size and the number of batches per device are read from the
    BATCH_SIZE and SAMPLE_COUNT environment variables (defaults 4 and 10).
    Every training step asserts that the loss stays below 3.0.
    """
    xla_devices = xm.get_xla_supported_devices()
    per_batch = xu.getenv_as('BATCH_SIZE', int, defval=4)
    batches_per_device = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
    fake_images = torch.zeros(per_batch, 3, 224, 224)
    fake_labels = torch.zeros(per_batch, dtype=torch.int64)
    synthetic_loader = xu.SampleGenerator(
        data=(fake_images, fake_labels),
        sample_count=batches_per_device * len(xla_devices))

    def loop_fn(model, loader, device, context):
        # Per-replica training loop. Every phase goes through xu.timed with
        # printfn=None (presumably collecting timings without printing them).
        criterion = nn.NLLLoss()
        sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        for images, labels in loader:
            with xu.TimedScope(msg='Training loop: ', printfn=None):
                sgd.zero_grad()
                logits = xu.timed(
                    lambda: model(images), msg='Model: ', printfn=None)
                step_loss = xu.timed(
                    lambda: criterion(logits, labels),
                    msg='Loss: ',
                    printfn=None)
                xu.timed(step_loss.backward, msg='LossBkw: ', printfn=None)
                xu.timed(
                    lambda: xm.optimizer_step(sgd), msg='Step: ', printfn=None)
                self.assertLess(step_loss.cpu().item(), 3.0)

    runner = dp.DataParallel(
        torchvision.models.resnet18, device_ids=xla_devices)
    runner(loop_fn, synthetic_loader)
def train_mnist():
    """Train the MNIST model on one XLA device and evaluate it after each epoch.

    All configuration is read from the module-level FLAGS object. When
    FLAGS.fake_data is set, all-zero batches from xu.SampleGenerator stand in
    for the real MNIST dataset.

    Returns:
      The test accuracy (in percent) computed after the last epoch, or 0.0 if
      FLAGS.num_epochs < 1.
    """
    torch.manual_seed(1)

    if FLAGS.fake_data:
        # sample_count is split across replicas so the global amount of
        # synthetic data matches the real dataset sizes (60000 / 10000).
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(
            FLAGS.datadir,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]))
        test_dataset = datasets.MNIST(
            FLAGS.datadir,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]))
        # Shard the training set only when running with several replicas.
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        # DataLoader does not accept shuffle=True together with a sampler, so
        # shuffling is requested here only when no sampler is used.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    # Scale learning rate to num cores
    lr = FLAGS.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
    # NOTE(review): NLLLoss expects log-probabilities; presumably MNIST()
    # ends in log_softmax -- confirm against the model definition.
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        """Run one training pass over `loader`, logging every FLAGS.log_steps."""
        tracker = xm.RateTracker()

        model.train()
        # NOTE(review): assumes the per-device loader yields
        # (step, (data, target)) pairs -- confirm for the torch_xla version.
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        """Return top-1 accuracy (percent) of `model` over `loader`.

        Raises ZeroDivisionError if the loader yields no samples.
        """
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            # Index of the highest score along dim 1 = predicted class.
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        # A fresh ParallelLoader is built for every pass over the data.
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    return accuracy
def train_imagenet():
    """Train an ImageNet model with dp.DataParallel and evaluate every epoch.

    The model constructor and input size come from get_model_property(); all
    other configuration is read from the module-level FLAGS object. With
    FLAGS.fake_data, all-zero batches from xu.SampleGenerator replace the real
    ImageFolder datasets.

    Returns:
      The mean of the per-device test accuracies (percent) from the last epoch,
      or 0.0 if FLAGS.num_epochs < 1.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        # The test batches are built with test_set_batch_size, so the batch
        # count must be derived from that same size (50000 val images).
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.test_set_batch_size //
            xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        # Shard both datasets only when running with several replicas.
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        # DataLoader rejects shuffle=True together with a sampler.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (
        xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
        if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    torchvision_model = get_model_property('model_fn')
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Run one training pass for a single replica.

        `writer` and `num_training_steps_per_epoch` are read from the enclosing
        scope; both are assigned below, before the first call.
        """
        loss_fn = nn.CrossEntropyLoss()
        # NOTE(review): getattr_or presumably returns the attribute cached on
        # `context` or creates it via the lambda, so optimizer/scheduler state
        # persists across epochs -- confirm in dp.Context.
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(
                model.parameters(),
                lr=FLAGS.lr,
                momentum=FLAGS.momentum,
                weight_decay=5e-4))
        lr_scheduler = context.getattr_or(
            'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
                scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=writer if xm.is_master_ordinal() else None))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
            # The scheduler is stepped once per training step.
            if lr_scheduler:
                lr_scheduler.step()

    def test_loop_fn(model, loader, device, context):
        """Return this replica's top-1 accuracy (percent) over `loader`."""
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         global_step)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    # Flush and release the event file once training is done.
    if writer is not None:
        writer.close()
    return accuracy
def train_mnist(FLAGS):
    """Train the MNIST model on all-ones synthetic data and optionally test it.

    Training calls `train_loop_fn` directly on the model for one XLA device.
    Evaluation, when FLAGS.run_test is set, goes through dp.DataParallel inside
    a `ptwse.scope.proxy_disabled` scope.

    Args:
      FLAGS: configuration object. Among others this reads batch_size,
        steps_per_epoch, num_cores, lr, momentum, log_steps, num_epochs,
        step_print_interval, use_autograph, test_max_step, mp, run_test,
        test_off_proxy, metrics_debug and logdir.

    Returns:
      The maximum mean test accuracy (percent) seen over all epochs; 0.0 if
      testing never ran.
    """
    DTYPE = torch.float32

    torch.manual_seed(1)

    dims = (
        FLAGS.batch_size,
        1,
        784,
    )

    train_dataset_len = FLAGS.steps_per_epoch if FLAGS.steps_per_epoch else 60000
    # Targets are float when the MSE loss is selected, int64 class ids otherwise.
    train_loader = xu.SampleGenerator(
        data=(
            torch.ones(dims, dtype=DTYPE,),
            torch.ones(
                FLAGS.batch_size,
                dtype=torch.int64 if not _MSE_LOSS else DTYPE,
            ),
        ),
        sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size(),
    )
    test_loader = xu.SampleGenerator(
        data=(
            torch.ones(dims, dtype=DTYPE,),
            torch.ones(
                FLAGS.batch_size,
                dtype=torch.int64 if not _MSE_LOSS else DTYPE,
            ),
        ),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size(),
    )

    devices = (
        xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
        if FLAGS.num_cores != 0
        else []
    )

    # Non multi-processing path.
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)

    model = MNIST(FLAGS)
    model_parallel = dp.DataParallel(
        model,
        device_ids=devices,
    )

    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)

    # Just some step closure output.
    # NOTE(review): not called anywhere in this function. The original
    # indentation was lost; the `else` below is paired with `if tracker`
    # (the only pairing in which last_step_timed can lag by more than one
    # step) -- confirm against version control.
    def train_output_fn(outputs, ctx, args, tracker):
        if ctx.step > 0 and args.log_steps and ctx.step % args.log_steps == 0:
            now_time = time.time()
            if hasattr(ctx, 'start_time') and ctx.start_time:
                per_step_time = (now_time - ctx.start_time) / (
                    ctx.step - ctx.last_step_timed
                )
                steps_per_second = 1 / per_step_time
                print(
                    f'[{xm.get_ordinal()}] Round-trip step time: '
                    f'{per_step_time} seconds, steps per second: {steps_per_second}'
                )
            if tracker:
                # `device` and `epoch` are read from the enclosing scope.
                _train_update(
                    device=device,
                    step=ctx.step,
                    loss=outputs[0],
                    tracker=tracker,
                    epoch=epoch,
                    writer=writer,
                )
                print(f'BEGIN Train step {ctx.step}')
                ctx.start_time = time.time()
                ctx.last_step_timed = ctx.step
            else:
                ctx.start_time = time.time()
                ctx.last_step_timed = ctx.step
        ctx.step += 1

    def train_loop_fn(model, loader, device=None, context=None):
        """Train `model` for one epoch over `loader`, then checkpoint it.

        `context` must provide `getattr_or` and a writable `step` attribute.
        `epoch` and `writer` are read from the enclosing scope.
        """
        lr_adder = 0.0

        if _MSE_LOSS:
            loss_fn = nn.MSELoss()
        else:
            loss_fn = nn.NLLLoss()

        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(
                model.parameters(),
                lr=lr + lr_adder,
                momentum=FLAGS.momentum,
            ),
        )

        tracker = xm.RateTracker()

        model.train()

        def train_inner_loop_fn(batch, ctx):
            """Run one optimizer step on `batch` and return `[loss]`."""
            step = ctx.step
            print(f'Step {step}')
            data = batch[0]
            target = batch[1]
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()

            xm.optimizer_step(
                optimizer,
                barrier=False,
            )

            if (
                FLAGS.log_steps != 0
                and (
                    FLAGS.log_steps == 1
                    or (step > 0 and step % FLAGS.log_steps == 0)
                )
            ):
                # Deferred so that `loss` is read only once the step executed.
                xm.add_step_closure(
                    _train_update,
                    args=(device, step, loss, tracker, epoch, writer),
                )

            if step == 0:
                xm.master_print(f"End TRAIN step {step}")

            ctx.step += 1
            return [loss]

        # Train
        print(f'Starting new epoch train loop... (epoch={epoch})')
        for step, (data, target) in enumerate(loader):
            if step % FLAGS.step_print_interval == 0:
                xm.master_print(f"Begin TRAIN Step: {step}")
            context.step = step

            if not FLAGS.use_autograph:
                train_inner_loop_fn((data, target), context)
            else:
                ptwse.flow.runner.maybe_run_converted(
                    train_inner_loop_fn,
                    (data, target),
                    context
                )

        xm.master_print("Saving model...")
        _save_checkpoint(FLAGS, device, None, model, is_epoch=True)
        xm.master_print("Model saved")

    def test_loop_fn(model, loader, device, context):
        """Return top-1 accuracy (percent) over at most FLAGS.test_max_step batches.

        With FLAGS.mp the value is averaged across replicas via
        xm.mesh_reduce. Returns 0.0 (still reduced under FLAGS.mp) when no
        samples were evaluated.
        """
        print("***********************")
        print("ENTERING TEST FUNCTION")
        print("***********************")
        print('Evaluating...')
        total_samples = 0
        correct = 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            if step >= FLAGS.test_max_step:
                break
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            if FLAGS.mp:
                # Keep the count on device; fetched once after the loop.
                correct += pred.eq(target.view_as(pred)).sum()
            else:
                correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        # `correct` stays a plain int when nothing was evaluated.
        correct_count = correct.item() if torch.is_tensor(correct) else correct
        this_accuracy = (
            100.0 * correct_count / total_samples if total_samples else 0.0
        )
        if FLAGS.mp:
            print("CALLING: mesh_reduce('test_accuracy')")
            this_accuracy = xm.mesh_reduce(
                'test_accuracy', this_accuracy, np.mean
            )
            print("BACK FROM: mesh_reduce('test_accuracy')")
        else:
            test_utils.print_test_update(device, this_accuracy)
        print("***********************")
        print("LEAVING TEST FUNCTION")
        print("***********************")
        return this_accuracy

    #
    # Set up for
    #
    accuracy = 0.0

    num_devices = (
        len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    )

    if not FLAGS.steps_per_epoch:
        num_training_steps_per_epoch = train_dataset_len // (
            FLAGS.batch_size * num_devices
        )
    else:
        num_training_steps_per_epoch = FLAGS.steps_per_epoch
    max_accuracy = 0.0

    #
    # Epoch loop
    #
    for epoch in range(1, FLAGS.num_epochs + 1):
        #
        # Train
        #
        device = xm.xla_device()
        ctx = dp.Context(device=device)
        ctx.tracker = xm.RateTracker()
        ctx.step = 0
        train_loop_fn(model, train_loader, device, ctx)

        #
        # Test
        #
        if FLAGS.run_test:
            with ptwse.scope.proxy_disabled(disabled=FLAGS.test_off_proxy):
                accuracies = model_parallel(test_loop_fn, test_loader)
                accuracy = mean(accuracies)
                print(
                    'Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy)
                )

            global_step = (epoch - 1) * num_training_steps_per_epoch
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                global_step,
                dict_to_write={'Accuracy/test': accuracy},
                write_xla_metrics=True,
            )

        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
def train_cifar():
    """Train ResNet-18 on CIFAR-10 with dp.DataParallel, evaluating each epoch.

    Uses torchvision's resnet18 when FLAGS.use_torchvision is set, otherwise
    the ResNet18 defined elsewhere in this module. All configuration is read
    from the module-level FLAGS object; FLAGS.fake_data swaps in all-zero
    synthetic batches.

    Returns:
      The mean of the per-device test accuracies (percent) from the last epoch,
      or 0.0 if FLAGS.num_epochs < 1.
    """
    print('==> Preparing data..')
    if FLAGS.fake_data:
        # Synthetic batch counts mirror CIFAR-10's 50000 / 10000 split,
        # divided across replicas.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = torchvision.datasets.CIFAR10(
            root=FLAGS.datadir,
            train=True,
            download=True,
            transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR10(
            root=FLAGS.datadir,
            train=False,
            download=True,
            transform=transform_test)
        # Shard both datasets only when running with several replicas.
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        # DataLoader rejects shuffle=True together with a sampler.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (
        xm.get_xla_supported_devices(
            max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    # `model` here is a constructor handed to DataParallel, not an instance;
    # the `model` parameter of the loop functions below shadows it.
    model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
    model_parallel = dp.DataParallel(model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Run one training pass for a single replica."""
        loss_fn = nn.CrossEntropyLoss()
        # NOTE(review): getattr_or presumably caches the optimizer on
        # `context` so it survives across epochs -- confirm in dp.Context.
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(
                model.parameters(),
                lr=FLAGS.lr,
                momentum=FLAGS.momentum,
                weight_decay=5e-4))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        """Return this replica's top-1 accuracy (percent) over `loader`.

        Raises ZeroDivisionError if the loader yields no samples.
        """
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        # One accuracy per device; reported as their unweighted mean.
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = sum(accuracies) / len(accuracies)
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    return accuracy
def train_mnist():
    """Train the MNIST model with dp.DataParallel and evaluate every epoch.

    All configuration is read from the module-level FLAGS object; with
    FLAGS.fake_data, all-zero synthetic batches replace the MNIST dataset.
    Per-epoch test accuracy is written to a summary writer in FLAGS.logdir.

    Returns:
      The mean of the per-device test accuracies (percent) from the last epoch,
      or 0.0 if FLAGS.num_epochs < 1.
    """
    torch.manual_seed(1)

    if FLAGS.fake_data:
        train_dataset_len = 60000  # Number of images in MNIST dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(
            FLAGS.datadir,
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.1307, ), (0.3081, ))]))
        train_dataset_len = len(train_dataset)
        test_dataset = datasets.MNIST(
            FLAGS.datadir,
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.1307, ), (0.3081, ))]))
        # Shard both datasets only when running with several replicas.
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        # DataLoader rejects shuffle=True together with a sampler.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    devices = (
        xm.get_xla_supported_devices(
            max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Run one training pass for a single replica."""
        loss_fn = nn.NLLLoss()
        # NOTE(review): getattr_or presumably caches the optimizer on
        # `context` so its state survives across epochs -- confirm.
        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(model.parameters(), lr=lr,
                              momentum=FLAGS.momentum))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        """Return this replica's top-1 accuracy (percent) over `loader`.

        Raises ZeroDivisionError if the loader yields no samples.
        """
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    # Not gated on xm.is_master_ordinal(), unlike some sibling scripts.
    writer = test_utils.get_summary_writer(FLAGS.logdir)
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    # Used only to place the per-epoch accuracy on a step-based x-axis.
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         global_step)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
def train_imagenet():
    """Train a timm-created ImageNet model on one XLA device, testing each epoch.

    Configuration is read from the module-level FLAGS object. When
    FLAGS.model_ema is set, a second model instance is wrapped in ModelEma,
    updated after every optimizer step, and evaluated after the main model.

    Returns:
      The last accuracy (percent) evaluated in the last epoch -- the EMA
      model's when EMA is enabled -- or 0.0 if FLAGS.epochs < 1.
    """
    torch.manual_seed(42)

    device = xm.xla_device()
    model = create_model(
        FLAGS.model,
        pretrained=FLAGS.pretrained,
        num_classes=FLAGS.num_classes,
        drop_rate=FLAGS.drop,
        global_pool=FLAGS.gp,
        bn_tf=FLAGS.bn_tf,
        bn_momentum=FLAGS.bn_momentum,
        bn_eps=FLAGS.bn_eps,
        drop_connect_rate=0.2,
        checkpoint_path=FLAGS.initial_checkpoint,
        args=FLAGS).to(device)

    model_ema = None
    if FLAGS.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper.
        model_e = create_model(
            FLAGS.model,
            pretrained=FLAGS.pretrained,
            num_classes=FLAGS.num_classes,
            drop_rate=FLAGS.drop,
            global_pool=FLAGS.gp,
            bn_tf=FLAGS.bn_tf,
            bn_momentum=FLAGS.bn_momentum,
            bn_eps=FLAGS.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=FLAGS.initial_checkpoint,
            args=FLAGS).to(device)
        # NOTE(review): with model_ema_force_cpu the EMA copy lives on CPU
        # while eval batches are on the XLA device -- confirm this works.
        model_ema = ModelEma(
            model_e,
            decay=FLAGS.model_ema_decay,
            device='cpu' if FLAGS.model_ema_force_cpu else '',
            resume=FLAGS.resume)

    print('==> Preparing data..')
    img_dim = 224
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dir = os.path.join(FLAGS.data, 'train')
        data_config = resolve_data_config(
            model, FLAGS, verbose=FLAGS.local_rank == 0)
        dataset_train = Dataset(train_dir)

        collate_fn = None
        if not FLAGS.no_prefetcher and FLAGS.mixup > 0:
            collate_fn = FastCollateMixup(
                FLAGS.mixup, FLAGS.smoothing, FLAGS.num_classes)
        train_loader = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=True,
            use_prefetcher=not FLAGS.no_prefetcher,
            rand_erase_prob=FLAGS.reprob,
            rand_erase_mode=FLAGS.remode,
            # FIXME cleanly resolve this? data_config['interpolation']
            interpolation='bicubic',
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
            collate_fn=collate_fn,
            use_auto_aug=FLAGS.auto_augment,
            use_mixcut=FLAGS.mixcut,
        )
        # Count samples, not batches: len(train_loader) is already a batch
        # count, and num_training_steps_per_epoch divides by batch size below.
        train_dataset_len = len(dataset_train)

        eval_dir = os.path.join(FLAGS.data, 'val')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
        dataset_eval = Dataset(eval_dir)
        test_loader = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=False,
            # Same prefetcher switch as the train loader.
            use_prefetcher=not FLAGS.no_prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
        )

    writer = None
    start_epoch = 0
    if FLAGS.output and xm.is_master_ordinal():
        writer = SummaryWriter(log_dir=FLAGS.output)

    optimizer = create_optimizer(FLAGS, model)
    # NOTE(review): this scheduler (and num_epochs) is superseded by the
    # wrap_optimizer_with_scheduler call below; the call is kept in case
    # constructing it adjusts the optimizer's param groups -- confirm.
    lr_scheduler, num_epochs = create_scheduler(FLAGS, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=FLAGS.smoothing)

    def train_loop_fn(loader):
        """Run one training pass, updating the EMA copy and LR each step."""
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = train_loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if model_ema is not None:
                model_ema.update(model)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def evaluate(net, loader):
        """Return top-1 accuracy (percent) of `net` over `loader`.

        Raises ZeroDivisionError if the loader yields no samples.
        """
        total_samples = 0
        correct = 0
        net.eval()
        for x, (data, target) in loader:
            output = net(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    def test_loop_fn(loader):
        return evaluate(model, loader)

    def test_loop_fn_ema(loader):
        # timm's ModelEma keeps the averaged nn.Module in `.ema` and is not a
        # Module itself; fall back to the wrapper for variants that are.
        return evaluate(getattr(model_ema, 'ema', model_ema), loader)

    accuracy = 0.0
    for epoch in range(1, FLAGS.epochs + 1):
        para_loader = dp.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))

        para_loader = dp.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        if model_ema is not None:
            # A fresh loader is required: the one above was fully consumed
            # by test_loop_fn and would yield no batches.
            ema_loader = dp.ParallelLoader(test_loader, [device])
            accuracy = test_loop_fn_ema(ema_loader.per_device_loader(device))
            print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    if writer is not None:
        writer.close()
    return accuracy
def train_mnist(flags, training_started=None, dynamic_graph=False, fetch_often=False):
    """Train MNIST on one XLA device with profiler tracing enabled.

    Args:
      flags: configuration object (batch_size, datadir, drop_last, num_workers,
        lr, momentum, logdir, profiler_port, log_steps, num_epochs,
        metrics_debug, fake_data).
      training_started: optional object with a `.set()` method; it is set
        once step 15 of an epoch is reached (testing aid for synchronization).
      dynamic_graph: testing only; vary the batch size per step so that a
        different graph is traced. Every sliced batch keeps at least one row.
      fetch_often: testing only; fetch the loss to CPU after every step.

    Returns:
      The highest test accuracy (percent, averaged across replicas with
      xm.mesh_reduce) seen over all epochs.
    """
    torch.manual_seed(1)

    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=600000 // flags.batch_size // xm.xrt_world_size(),
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=100000 // flags.batch_size // xm.xrt_world_size(),
        )
    else:
        # Each ordinal downloads into its own sub-directory.
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
            )
        # DataLoader rejects shuffle=True together with a sampler.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers,
        )

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    # Keep a reference to the profiler server for the whole run.
    server = xp.start_server(flags.profiler_port)

    def train_loop_fn(loader):
        """Run one traced training pass over `loader`."""
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                # Rows kept grow with `step`, capped at batch_size - 1 and
                # floored at 1 (step 0 previously produced an empty batch).
                rows = max(1, min(step, flags.batch_size - 1))
                data, target = data[:rows, :, :, :], target[:rows]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()

            with xp.StepTrace("train_mnist", step_num=step):
                with xp.Trace("build_graph"):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(_train_update, args=(device, step, loss, tracker, writer))

    def test_loop_fn(loader):
        """Return top-1 accuracy (percent) averaged across replicas."""
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace("test_mnist"):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                # Accumulate on device; fetched once after the loop.
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print("Epoch {} train end {}".format(epoch, test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print(
            "Epoch {} test end {}, Accuracy={:.2f}".format(epoch, test_utils.now(), accuracy)
        )
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer, epoch, dict_to_write={"Accuracy/test": accuracy}, write_xla_metrics=True
        )
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy
def train_mnist():
  """Train the MNIST model on this XLA replica and evaluate after each epoch.

  With FLAGS.fake_data the loaders yield constant zero batches; otherwise the
  real MNIST dataset is downloaded into a per-ordinal subdirectory of
  FLAGS.datadir and the training split is sharded with a DistributedSampler
  when more than one replica is running.

  Returns:
    The highest test accuracy (percent) observed over all epochs.
  """
  torch.manual_seed(1)

  if FLAGS.fake_data:

    def _zero_loader(num_samples):
      # Fresh zero tensors per loader, sized from the global batch size.
      return xu.SampleGenerator(
          data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
          sample_count=num_samples // FLAGS.batch_size // xm.xrt_world_size())

    train_loader = _zero_loader(60000)
    test_loader = _zero_loader(10000)
  else:
    ordinal_dir = os.path.join(FLAGS.datadir, str(xm.get_ordinal()))

    def _mnist_split(is_train):
      # Each split gets its own normalization pipeline.
      return datasets.MNIST(
          ordinal_dir,
          train=is_train,
          download=True,
          transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,)),
          ]))

    train_dataset = _mnist_split(True)
    test_dataset = _mnist_split(False)

    world_size = xm.xrt_world_size()
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=xm.get_ordinal(),
            shuffle=True) if world_size > 1 else None)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=FLAGS.drop_last,
        # DataLoader forbids shuffle together with a sampler.
        shuffle=not train_sampler,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        drop_last=FLAGS.drop_last,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  device = xm.xla_device()
  model = MNIST().to(device)
  writer = (test_utils.get_summary_writer(FLAGS.logdir)
            if xm.is_master_ordinal() else None)
  # Learning rate is scaled linearly with the number of replicas.
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr * xm.xrt_world_size(),
      momentum=FLAGS.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      loss = loss_fn(model(data), target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if step % FLAGS.log_steps == 0:
        # Progress reporting is scheduled through an XLA step closure.
        xm.add_step_closure(_train_update, args=(device, step, loss, tracker))

  def test_loop_fn(loader):
    model.eval()
    num_seen, num_correct = 0, 0
    for data, target in loader:
      pred = model(data).max(1, keepdim=True)[1]
      num_correct += pred.eq(target.view_as(pred)).sum().item()
      num_seen += data.size()[0]
    accuracy = 100.0 * num_correct / num_seen
    test_utils.print_test_update(device, accuracy)
    return accuracy

  max_accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print('Finished training epoch {}'.format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy = test_loop_fn(para_loader.per_device_loader(device))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        epoch,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True)
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy
def train_imagenet():
  """Train and evaluate an ImageNet classifier wrapped in FSDP on XLA devices.

  Data comes from constant fake batches when FLAGS.fake_data is set, otherwise
  from ImageFolder datasets under FLAGS.datadir ('train' and 'val'). The model
  from get_model_property('model_fn') is always wrapped in an outer FSDP; with
  FLAGS.use_nested_fsdp every child module that owns parameters is wrapped as
  well (after optional gradient checkpointing).

  Returns:
    The highest mesh-reduced test accuracy (percent) observed, or 0.0 if no
    evaluation ran.
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    # Each fake test batch holds FLAGS.test_set_batch_size samples, so the
    # number of batches covering ~50k validation images must use that size.
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.test_set_batch_size //
        xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler, test_sampler = None, None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only keep workers alive when there are some.
    persistent_workers = FLAGS.num_workers > 0
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False if train_sampler else True,
        persistent_workers=persistent_workers,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False,
        persistent_workers=persistent_workers,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  device = xm.xla_device()
  model = get_model_property('model_fn')()

  # Wrap the model with FSDP
  # You may wrap all, a subset, or none of the sub-modules with inner FSDPs
  # - to implement ZeRO-2, wrap none of the sub-modules
  # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP)
  # - you may wrap sub-modules at different granularity (e.g. at each resnet
  #   stage or each residual block or each conv layer).
  def fsdp_wrap(module):
    # The module is moved to `device` before being handed to FSDP.
    return FSDP(
        module.to(device),
        compute_dtype=getattr(torch, FLAGS.compute_dtype),
        fp32_reduce_scatter=FLAGS.fp32_reduce_scatter,
        flatten_parameters=FLAGS.flatten_parameters)

  # Apply gradient checkpointing to sub-modules if specified
  if FLAGS.use_gradient_checkpointing:
    grad_ckpt_wrap = checkpoint_module
  else:
    grad_ckpt_wrap = lambda module: module

  if FLAGS.use_nested_fsdp:
    # Here we apply inner FSDP at the level of child modules for ZeRO-3, which
    # corresponds to different stages in resnet (i.e. Stage 1 to 5).
    for submodule_name, submodule in model.named_children():
      if sum(p.numel() for p in submodule.parameters()) == 0:
        # Skip those submodules without parameters (i.e. no need to shard them)
        continue
      # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
      m_fsdp = fsdp_wrap(grad_ckpt_wrap(getattr(model, submodule_name)))
      setattr(model, submodule_name, m_fsdp)
  # Always wrap the base model with an outer FSDP
  model = fsdp_wrap(model)

  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.WarmupAndExponentialDecayScheduler(
      optimizer,
      num_steps_per_epoch=num_training_steps_per_epoch,
      divide_every_n_epochs=FLAGS.lr_scheduler_divide_every_n_epochs,
      divisor=FLAGS.lr_scheduler_divisor,
      num_warmup_epochs=FLAGS.num_warmup_epochs,
      summary_writer=writer)
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()  # do not reduce gradients on sharded params
      tracker.add(FLAGS.batch_size)
      if lr_scheduler:
        lr_scheduler.step()
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            _train_update, args=(device, step, loss, tracker, epoch, writer))

  def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(loader):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            test_utils.print_test_update, args=(device, None, epoch, step))
    accuracy = 100.0 * correct.item() / total_samples
    # Average the per-replica accuracies across all replicas.
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
    # Evaluate every eval_interval epochs (unless test_only_at_end) and
    # always after the final epoch.
    run_eval = ((not FLAGS.test_only_at_end and
                 epoch % FLAGS.eval_interval == 0) or
                epoch == FLAGS.num_epochs)
    if run_eval:
      accuracy = test_loop_fn(test_device_loader, epoch)
      xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
          epoch, test_utils.now(), accuracy))
      max_accuracy = max(accuracy, max_accuracy)
      test_utils.write_to_summary(
          writer,
          epoch,
          dict_to_write={'Accuracy/test': accuracy},
          write_xla_metrics=True)
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy
def train_imagenet():
  """Train an ImageNet classifier with autocast and GradScaler on XLA devices.

  Gradients are all-reduced explicitly across replicas before the scaled
  optimizer step. Validation data, evaluation and summaries are only set up
  when FLAGS.validate is true. With FLAGS.fine_grained_metrics, p50 wall-clock
  step/forward/backward latencies are printed instead of RateTracker updates.

  Returns:
    The best mesh-reduced test accuracy (percent) when FLAGS.validate is set,
    otherwise None.
  """
  print("==> Preparing data..")
  img_dim = get_model_property("img_dim")
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(
            torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
            torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        ),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size(),
    )
    if FLAGS.validate:
      # Fake test batches hold FLAGS.test_set_batch_size samples, so the
      # batch count for ~50k validation images must use that size.
      test_loader = xu.SampleGenerator(
          data=(
              torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
          ),
          sample_count=50000 // FLAGS.test_set_batch_size //
          xm.xrt_world_size(),
      )
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, "train"),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    if FLAGS.validate:
      test_dataset = torchvision.datasets.ImageFolder(
          os.path.join(FLAGS.datadir, "val"),
          # Matches Torchvision's eval transforms except Torchvision uses size
          # 256 resize for all models both here and in the train loader. Their
          # version crashes during training on 299x299 images, e.g. inception.
          transforms.Compose([
              transforms.Resize(resize_dim),
              transforms.CenterCrop(img_dim),
              transforms.ToTensor(),
              normalize,
          ]),
      )
    train_sampler, test_sampler = None, None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      if FLAGS.validate:
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers,
    )
    if FLAGS.validate:
      test_loader = torch.utils.data.DataLoader(
          test_dataset,
          batch_size=FLAGS.test_set_batch_size,
          sampler=test_sampler,
          drop_last=FLAGS.drop_last,
          shuffle=False,
          num_workers=FLAGS.num_workers,
      )

  device = xm.xla_device()
  model = get_model_property("model_fn")().to(device)
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
      optimizer,
      scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
      scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
      scheduler_divide_every_n_epochs=getattr(
          FLAGS, "lr_scheduler_divide_every_n_epochs", None),
      num_steps_per_epoch=num_training_steps_per_epoch,
      summary_writer=writer,
  )
  loss_fn = nn.CrossEntropyLoss()
  scaler = GradScaler()

  def _latency_report(step_latencies, fwd_latencies, bwd_latencies):
    """Format the p50 rate/latency fields shared by the fine-grained logs.

    Each Fwd/Bwd field is paired with its own tracker (they were previously
    swapped).
    """
    step_p50 = p50(step_latencies)
    return ("Rate(DataPoints/s)[p50]={:.1f} BatchSize={} "
            "Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} "
            "Bwd(s/Batch)[p50]={:.4f}").format(
                FLAGS.batch_size / step_p50,
                FLAGS.batch_size,
                step_p50,
                p50(fwd_latencies),
                p50(bwd_latencies),
            )

  def train_loop_fn(loader, epoch):
    if FLAGS.fine_grained_metrics:
      epoch_start_time = time.time()
      step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
    else:
      tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      if FLAGS.fine_grained_metrics:
        step_start_time = time.time()
      optimizer.zero_grad()
      if FLAGS.fine_grained_metrics:
        fwd_start_time = time.time()
      with autocast():
        output = model(data)
        loss = loss_fn(output, target)
      if FLAGS.fine_grained_metrics:
        fwd_latency = time.time() - fwd_start_time
        bwd_start_time = time.time()
      scaler.scale(loss).backward()
      # Average gradients across replicas before the (scaled) optimizer step.
      gradients = xm._fetch_gradients(optimizer)
      xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
      scaler.step(optimizer)
      scaler.update()
      xm.mark_step()
      if lr_scheduler:
        lr_scheduler.step()
      if FLAGS.fine_grained_metrics:
        # "Bwd" spans backward, all-reduce, optimizer step, mark_step and the
        # scheduler step.
        bwd_end_time = time.time()
        step_latency_tracker.append(bwd_end_time - step_start_time)
        bwd_latency_tracker.append(bwd_end_time - bwd_start_time)
        fwd_latency_tracker.append(fwd_latency)
      else:
        tracker.add(FLAGS.batch_size)
      if step % FLAGS.log_steps == 0:
        if FLAGS.fine_grained_metrics:
          print("FineGrainedMetrics :: Epoch={} Step={} {}".format(
              epoch, step,
              _latency_report(step_latency_tracker, fwd_latency_tracker,
                              bwd_latency_tracker)))
        else:
          xm.add_step_closure(
              _train_update, args=(device, step, loss, tracker, epoch, writer))
    if FLAGS.fine_grained_metrics:
      epoch_latency = time.time() - epoch_start_time
      # `{:.}` (no precision) raised ValueError; use an explicit precision.
      print("FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} {}".format(
          epoch, epoch_latency,
          _latency_report(step_latency_tracker, fwd_latency_tracker,
                          bwd_latency_tracker)))

  def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(loader):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      if step % FLAGS.log_steps == 0:
        test_utils.print_test_update(device, None, epoch, step)
    accuracy = 100.0 * correct.item() / total_samples
    # Average the per-replica accuracies across all replicas.
    accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
    return accuracy

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  if FLAGS.validate:
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    xm.master_print("Epoch {} train begin {}".format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print("Epoch {} train end {}".format(epoch, test_utils.now()))
    if FLAGS.validate:
      accuracy = test_loop_fn(test_device_loader, epoch)
      xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
          epoch, test_utils.now(), accuracy))
      max_accuracy = max(accuracy, max_accuracy)
      test_utils.write_to_summary(
          writer,
          epoch,
          dict_to_write={"Accuracy/test": accuracy},
          write_xla_metrics=True)
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  if FLAGS.validate:
    xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
  return max_accuracy if FLAGS.validate else None
def train_imagenet():
  """Train an ImageNet classifier, optionally warm-starting from a checkpoint.

  Each epoch reports train loss/top-1/top-5 and test top-1/top-5 (printed and
  written to a SummaryWriter on the master ordinal when FLAGS.logdir is set),
  and saves a checkpoint under ./reports/ whenever test top-1 improves.

  Returns:
    The test top-1 accuracy (percent) of the last epoch run, or 0.0 if no
    epoch ran.
  """
  print("==> Preparing data..")
  img_dim = get_model_property("img_dim")
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(
            torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
            torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        ),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size(),
    )
    # Fake test batches hold FLAGS.test_set_batch_size samples, so the batch
    # count for ~50k validation images must use that size.
    test_loader = xu.SampleGenerator(
        data=(
            torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
            torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
        ),
        sample_count=50000 // FLAGS.test_set_batch_size //
        xm.xrt_world_size(),
    )
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, "train"),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, "val"),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    train_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True,
      )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers,
    )

  torch.manual_seed(42)

  device = xm.xla_device()
  # Move the model to the device BEFORE constructing the optimizer, so the
  # optimizer holds the on-device parameters.
  model = get_model_property("model_fn")().to(device)
  writer = None
  if FLAGS.logdir and xm.is_master_ordinal():
    writer = SummaryWriter(log_dir=FLAGS.logdir)
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
      optimizer,
      scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
      scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
      scheduler_divide_every_n_epochs=getattr(
          FLAGS, "lr_scheduler_divide_every_n_epochs", None),
      num_steps_per_epoch=num_training_steps_per_epoch,
      summary_writer=writer,
  )

  # Last completed epoch; training resumes with the epoch after it.
  start_epoch = 0
  if FLAGS.warm_start:
    # An optional FLAGS.warm_start_path overrides the historical default.
    checkpoint_path = (getattr(FLAGS, "warm_start_path", None) or
                       "./reports/resnet152_model-26.pt")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["state_dict"])
    # Loaded after the parameters live on `device` so restored optimizer
    # state (e.g. momentum buffers) ends up alongside them.
    optimizer.load_state_dict(checkpoint["optimizer"])
    if lr_scheduler is not None:
      lr_scheduler._step_count = checkpoint["step"]
    start_epoch = checkpoint["epoch"]

  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader):
    """Run one training epoch.

    Returns:
      (mean per-batch loss, top-1 accuracy in percent, mean per-batch top-5
      value as returned by topk_accuracy).
    """
    tracker = xm.RateTracker()
    model.train()
    total_samples = 0
    correct = 0
    top5_accuracys = 0
    losses = 0
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      losses += loss.item()
      total_samples += data.size()[0]
      top5_accuracys += topk_accuracy(output, target, topk=5).item()
      if lr_scheduler:
        lr_scheduler.step()
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(), tracker.global_rate())
    return (
        losses / (x + 1),
        (100.0 * correct / total_samples),
        (top5_accuracys / (x + 1)),
    )

  def test_loop_fn(loader):
    """Evaluate once; return (top-1 accuracy %, mean per-batch top-5 value)."""
    total_samples = 0
    correct = 0
    top5_accuracys = 0
    model.eval()
    for x, (data, target) in enumerate(loader):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
      top5_accuracys += topk_accuracy(output, target, topk=5).item()
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy, top5_accuracys / (x + 1)

  accuracy = 0.0
  max_accuracy = 0.0
  start = time.time()
  # Start after the last completed epoch: a cold start runs epochs
  # 1..num_epochs, a warm start does not repeat the checkpointed epoch.
  for epoch in range(start_epoch + 1, FLAGS.num_epochs + 1):
    epoch_start = time.time()
    para_loader = pl.ParallelLoader(
        train_loader, [device], loader_prefetch_size=32, device_prefetch_size=8)
    loss, accuracy, top5_accuracy = train_loop_fn(
        para_loader.per_device_loader(device))
    if xm.is_master_ordinal():
      print(
          "Finished training epoch {}, duration_time {} sec, total duration_time {} sec"
          .format(epoch, time.time() - epoch_start, time.time() - start))
      print(
          "Epoch: {} (Train), Loss {}, Top-1 Accuracy: {:.2f} Top-5 accuracy: {}"
          .format(epoch, loss, accuracy, top5_accuracy))
      test_utils.add_scalar_to_summary(writer, "Loss/train", loss, epoch)
      test_utils.add_scalar_to_summary(writer, "Top-1 Accuracy/train",
                                       accuracy, epoch)
      test_utils.add_scalar_to_summary(writer, "Top-5 Accuracy/train",
                                       top5_accuracy, epoch)

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, top5_accuracy = test_loop_fn(
        para_loader.per_device_loader(device))
    if xm.is_master_ordinal():
      print("Epoch: {} (Valid), Top-1 Accuracy: {:.2f} Top-5 accuracy: {}"
            .format(epoch, accuracy, top5_accuracy))
      test_utils.add_scalar_to_summary(writer, "Top-1 Accuracy/test",
                                       accuracy, epoch)
      test_utils.add_scalar_to_summary(writer, "Top-5 Accuracy/test",
                                       top5_accuracy, epoch)
    if FLAGS.metrics_debug:
      print(met.metrics_report())
    if accuracy > max_accuracy:
      max_accuracy = accuracy
      xm.save(
          {
              "epoch": epoch,
              "step": (lr_scheduler._step_count
                       if lr_scheduler is not None else 0),
              "state_dict": model.state_dict(),
              "optimizer": optimizer.state_dict(),
          },
          f"./reports/{FLAGS.model}_model-{epoch}.pt",
          master_only=True,
      )
    if writer is not None:
      writer.flush()
  return accuracy
def train_mnist(index):
  """Train MNIST on fake zero data inside one spawned process.

  Args:
    index: Process index; only used by the per-step debug print.

  Returns:
    Test accuracy (percent) from the final epoch, or 0.0 if no epoch ran.
  """
  torch.manual_seed(1)

  def _zero_loader(num_samples):
    # Constant zero batches; count is split across all replicas.
    return xu.SampleGenerator(
        data=(
            torch.zeros(FLAGS.batch_size, 1, 28, 28),
            torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        ),
        sample_count=num_samples // FLAGS.batch_size // xm.xrt_world_size(),
    )

  train_loader = _zero_loader(60000)
  test_loader = _zero_loader(10000)

  device = xm.xla_device()
  model = MNIST().to(device)
  # Learning rate is scaled linearly with the number of replicas.
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr * xm.xrt_world_size(),
      momentum=FLAGS.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      print("Index %d data %d" % (index, step))
      optimizer.zero_grad()
      loss = loss_fn(model(data), target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if step % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, step, loss.item(),
                                         tracker.rate(), tracker.global_rate())

  def test_loop_fn(loader):
    model.eval()
    seen, hits = 0, 0
    for data, target in loader:
      pred = model(data).max(1, keepdim=True)[1]
      hits += pred.eq(target.view_as(pred)).sum().item()
      seen += data.size()[0]
    accuracy = 100.0 * hits / seen
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for _epoch in range(1, FLAGS.num_epochs + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS.metrics_debug:
      print(met.metrics_report())

  return accuracy
def train_imagenet():
  """Train and evaluate an ImageNet classifier on XLA devices.

  Uses constant fake batches when FLAGS.fake_data is set, otherwise
  ImageFolder datasets under FLAGS.datadir ('train' and 'val'), sharded with
  DistributedSampler when more than one replica runs.

  Returns:
    The highest mesh-reduced test accuracy (percent) observed over all epochs.
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    # Fake test batches hold FLAGS.test_set_batch_size samples, so the batch
    # count for ~50k validation images must use that size.
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.test_set_batch_size //
        xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler, test_sampler = None, None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  device = xm.xla_device()
  model = get_model_property('model_fn')().to(device)
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
      optimizer,
      scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
      scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
      scheduler_divide_every_n_epochs=getattr(
          FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
      num_steps_per_epoch=num_training_steps_per_epoch,
      summary_writer=writer)
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if lr_scheduler:
        lr_scheduler.step()
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            _train_update, args=(device, step, loss, tracker, epoch, writer))

  def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(loader):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            test_utils.print_test_update, args=(device, None, epoch, step))
    accuracy = 100.0 * correct.item() / total_samples
    # Average the per-replica accuracies across all replicas.
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
    accuracy = test_loop_fn(test_device_loader, epoch)
    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
        epoch, test_utils.now(), accuracy))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        epoch,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True)
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy
def train_imagenet():
  """Train an ImageNet classifier with the legacy DataParallel API.

  Replicates the model over the XLA devices selected by FLAGS.num_cores (or
  runs on CPU when it is 0). Each epoch aggregates per-replica train/test
  metrics, prints and logs them, and saves replica 0's weights whenever the
  mean test top-1 accuracy improves.

  Returns:
    The best mean test top-1 accuracy (percent) seen across epochs.
  """
  print("==> Preparing data..")
  img_dim = get_model_property("img_dim")
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(
            torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
            torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        ),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size(),
    )
    # Fake test batches hold FLAGS.test_set_batch_size samples, so the batch
    # count for ~50k validation images must use that size.
    test_loader = xu.SampleGenerator(
        data=(
            torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
            torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
        ),
        sample_count=50000 // FLAGS.test_set_batch_size //
        xm.xrt_world_size(),
    )
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, "train"),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, "val"),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True,
      )
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False,
      )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=True,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        drop_last=False,
        shuffle=False,
        num_workers=FLAGS.num_workers,
    )

  torch.manual_seed(42)

  devices = (
      xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
      if FLAGS.num_cores != 0 else ["cpu"])
  print("use tpu devices", devices)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property("model_fn")
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    """Train one epoch on one replica.

    Returns:
      (mean per-batch loss, top-1 accuracy in percent, mean per-batch top-5
      value as returned by topk_accuracy).
    """
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer and scheduler are obtained via context.getattr_or, presumably
    # so each replica reuses them across epochs — TODO confirm.
    optimizer = context.getattr_or(
        "optimizer",
        lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=1e-4,
        ),
    )
    lr_scheduler = context.getattr_or(
        "lr_scheduler",
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
            scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
            scheduler_divide_every_n_epochs=getattr(
                FLAGS, "lr_scheduler_divide_every_n_epochs", None),
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=writer if xm.is_master_ordinal() else None,
        ),
    )
    tracker = xm.RateTracker()
    model.train()
    total_samples = 0
    correct = 0
    top5_accuracys = 0
    losses = 0
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      top5_accuracys += topk_accuracy(output, target, topk=5)
      loss = loss_fn(output, target)
      loss.backward()
      losses += loss.item()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print(
            "[{}]({}) Loss={:.5f} Top-1 ACC = {:.2f} Rate={:.2f} GlobalRate={:.2f} Time={}"
            .format(
                str(device),
                x,
                loss.item(),
                (100.0 * correct / total_samples).item(),
                tracker.rate(),
                tracker.global_rate(),
                time.asctime(),
            ))
      if lr_scheduler:
        lr_scheduler.step()
    return (
        losses / (x + 1),
        (100.0 * correct / total_samples).item(),
        (top5_accuracys / (x + 1)).item(),
    )

  def test_loop_fn(model, loader, device, context):
    """Evaluate on one replica.

    Returns:
      (top-1 accuracy in percent, SUM of per-batch top-5 values); the caller
      normalizes the top-5 sum by the test loader length.
    """
    total_samples = 0
    correct = 0
    top5_accuracys = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
      top5_accuracys += topk_accuracy(output, target, topk=5).item()
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy, top5_accuracys

  accuracy = 0.0
  writer = SummaryWriter(FLAGS.logdir) if FLAGS.logdir else None
  num_devices = (
      len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * num_devices)
  max_accuracy = 0.0
  print("train_loader_len", len(train_loader), num_training_steps_per_epoch)
  print("test_loader_len", len(test_loader))
  for epoch in range(1, FLAGS.num_epochs + 1):
    global_step = (epoch - 1) * num_training_steps_per_epoch

    # Train evaluate
    metrics = model_parallel(train_loop_fn, train_loader)
    losses, accuracies, top5_accuracys = zip(*metrics)
    loss = mean(losses)
    accuracy = mean(accuracies)
    top5_accuracy = mean(top5_accuracys)
    print(
        "Epoch: {} (Train), Loss {}, Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}"
        .format(epoch, loss, accuracy, top5_accuracy))
    test_utils.add_scalar_to_summary(writer, "Loss/train", loss, global_step)
    test_utils.add_scalar_to_summary(writer, "Top-1 Accuracy/train", accuracy,
                                     global_step)
    test_utils.add_scalar_to_summary(writer, "Top-5 Accuracy/train",
                                     top5_accuracy, global_step)

    # Test evaluate
    metrics = model_parallel(test_loop_fn, test_loader)
    accuracies, top5_accuracys = zip(*metrics)
    top5_accuracys = sum(top5_accuracys)
    top5_accuracy = top5_accuracys / len(test_loader)
    accuracy = mean(accuracies)
    print("Epoch: {} (Valid), Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}"
          .format(epoch, accuracy, top5_accuracy))
    test_utils.add_scalar_to_summary(writer, "Top-1 Accuracy/test", accuracy,
                                     global_step)
    test_utils.add_scalar_to_summary(writer, "Top-5 Accuracy/test",
                                     top5_accuracy, global_step)
    if FLAGS.metrics_debug:
      print(met.metrics_report())
    if accuracy > max_accuracy:
      max_accuracy = accuracy
      # Save a CPU copy of the weights instead of moving replica 0 to CPU and
      # back, which could rebind the parameters the cached optimizer holds.
      cpu_state_dict = {
          name: tensor.cpu()
          for name, tensor in model_parallel.models[0].state_dict().items()
      }
      torch.save(cpu_state_dict, f"./reports/resnet50_model-{epoch}.pt")

  test_utils.close_summary_writer(writer)
  # Report the best accuracy, not the last epoch's value.
  print("Max Accuracy: {:.2f}%".format(max_accuracy))
  return max_accuracy
def train_mnist(flags, state_dict):
  """Train MNIST on the current XLA device, starting from ``state_dict``.

  Builds the train/test loaders, loads ``state_dict`` into a fresh ``MNIST``
  model and runs ``flags.num_epochs`` epochs of SGD. The test set is evaluated
  after every epoch.

  The loaders yield synthetic all-zero batches when ``flags.fake_data`` is
  set. Otherwise they read the real MNIST dataset from a per-ordinal
  subdirectory of ``flags.datadir``.

  Args:
    flags: Parsed command-line options (batch size, learning rate, momentum,
      epochs, loader/logging settings, ...).
    state_dict: Initial model weights, passed to ``MNIST.load_state_dict``.

  Returns:
    The highest per-epoch test accuracy (in percent) seen during training.
  """
  if flags.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64)),
        sample_count=60000 // flags.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64)),
        sample_count=10000 // flags.batch_size // xm.xrt_world_size())
  else:
    # Each ordinal uses its own subdirectory of datadir (presumably so that
    # concurrent processes do not download into the same files).
    train_dataset = datasets.MNIST(
        os.path.join(flags.datadir, str(xm.get_ordinal())),
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))
    test_dataset = datasets.MNIST(
        os.path.join(flags.datadir, str(xm.get_ordinal())),
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))
    train_sampler = None
    if xm.xrt_world_size() > 1:
      # Shard the training set across replicas.
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags.batch_size,
        sampler=train_sampler,
        drop_last=flags.drop_last,
        # DataLoader does not accept shuffle=True together with a sampler.
        shuffle=False if train_sampler else True,
        num_workers=flags.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=flags.batch_size,
        drop_last=flags.drop_last,
        shuffle=False,
        num_workers=flags.num_workers)

  # Scale learning rate to num cores
  lr = flags.lr * xm.xrt_world_size()

  device = xm.xla_device()
  model = MNIST()
  model.load_state_dict(state_dict)
  model = model.to(device)
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(flags.logdir)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    """Run one training epoch over ``loader``."""
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(flags.batch_size)
      if step % flags.log_steps == 0:
        # Logging goes through a step closure instead of reading `loss`
        # inline. Use the `flags` passed to this function (not the global
        # FLAGS), consistent with every other option read here.
        xm.add_step_closure(
            _train_update,
            args=(device, step, loss, tracker, writer),
            run_async=flags.async_closures)

  def test_loop_fn(loader):
    """Return the top-1 accuracy (percent) of ``model`` over ``loader``."""
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]

    if total_samples == 0:
      # Empty loader: `correct` is still the int 0 (no `.item()`), and the
      # division below would raise ZeroDivisionError.
      return 0.0
    # The test loader has no sampler, so every replica evaluates the full
    # test set; no cross-replica reduction is performed here.
    accuracy = 100.0 * correct.item() / total_samples
    return accuracy

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, flags.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

    accuracy = test_loop_fn(test_device_loader)
    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
        epoch, test_utils.now(), accuracy))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        epoch,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True)
    if flags.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy
def train_mnist(flags, **kwargs):
  """Train MNIST with FSDP on the current XLA device.

  The model is always wrapped in an outer ``FSDP``. With
  ``flags.use_nested_fsdp``, its ``conv1``/``conv2``/``fc1``/``fc2``
  sub-modules are also wrapped in inner ``FSDP``, after optional gradient
  checkpointing.

  If ``flags.ckpt_consolidation`` is set, three extra steps run after
  training:
    * every rank saves its sharded checkpoint;
    * the global master consolidates the shards;
    * all ranks evaluate the consolidated (unsharded) model.

  Args:
    flags: Parsed command-line options.
    **kwargs: Accepted but not used by this function.

  Returns:
    The highest per-epoch test accuracy (in percent), where each epoch's
    accuracy is averaged across replicas with ``xm.mesh_reduce``.
  """
  torch.manual_seed(1)
  if flags.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64)),
        sample_count=60000 // flags.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64)),
        sample_count=10000 // flags.batch_size // xm.xrt_world_size())
  else:
    # Each ordinal uses its own subdirectory of datadir (presumably so that
    # concurrent processes do not download into the same files).
    train_dataset = datasets.MNIST(
        os.path.join(flags.datadir, str(xm.get_ordinal())),
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    test_dataset = datasets.MNIST(
        os.path.join(flags.datadir, str(xm.get_ordinal())),
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    train_sampler = None
    if xm.xrt_world_size() > 1:
      # Shard the training set across replicas.
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags.batch_size,
        sampler=train_sampler,
        drop_last=flags.drop_last,
        # DataLoader does not accept shuffle=True together with a sampler.
        shuffle=False if train_sampler else True,
        num_workers=flags.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=flags.batch_size,
        drop_last=flags.drop_last,
        shuffle=False,
        num_workers=flags.num_workers)

  # Scale learning rate to num cores
  lr = flags.lr * xm.xrt_world_size()

  device = xm.xla_device()
  model = MNIST()

  def fsdp_wrap(m):
    """Move ``m`` to the XLA device and wrap it with FSDP per ``flags``."""
    return FSDP(
        m.to(device),
        compute_dtype=getattr(torch, flags.compute_dtype),
        fp32_reduce_scatter=flags.fp32_reduce_scatter,
        flatten_parameters=flags.flatten_parameters)

  # Apply gradient checkpointing to sub-modules if specified
  grad_ckpt_wrap = checkpoint_module if flags.use_gradient_checkpointing else (
      lambda x: x)
  if flags.use_nested_fsdp:
    # Wrap a few sub-modules with inner FSDP (to implement ZeRO-3)
    # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
    model.conv1 = fsdp_wrap(grad_ckpt_wrap(model.conv1))
    model.conv2 = fsdp_wrap(grad_ckpt_wrap(model.conv2))
    model.fc1 = fsdp_wrap(grad_ckpt_wrap(model.fc1))
    model.fc2 = fsdp_wrap(grad_ckpt_wrap(model.fc2))
  # Always wrap the base model with an outer FSDP
  model = fsdp_wrap(model)

  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(flags.logdir)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(model, loader):
    """Run one training epoch of ``model`` over ``loader``."""
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()  # do not reduce gradients on sharded params
      tracker.add(flags.batch_size)
      if step % flags.log_steps == 0:
        # Use the `flags` passed to this function (not the global FLAGS),
        # consistent with every other option read here.
        xm.add_step_closure(
            _train_update,
            args=(device, step, loss, tracker, writer),
            run_async=flags.async_closures)

  def test_loop_fn(model, loader):
    """Return top-1 accuracy (percent) of ``model``, averaged over replicas."""
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]

    # Guard an empty loader (`correct` would still be the int 0 and the
    # division would fail). Do not return early: every replica must still
    # reach the collective mesh_reduce below.
    accuracy = 100.0 * correct.item() / total_samples if total_samples else 0.0
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, flags.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(model, train_device_loader)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

    accuracy = test_loop_fn(model, test_device_loader)
    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
        epoch, test_utils.now(), accuracy))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        epoch,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True)
    if flags.metrics_debug:
      xm.master_print(met.metrics_report())

  if flags.ckpt_consolidation:
    # Note: to run this test, all the model checkpoints needs to be
    # accessible from the master rank. Set --ckpt_prefix to a shared file
    # system (e.g. NFS) when running on a TPU pod.

    # Save the final model checkpoint
    rank = xm.get_ordinal()
    world_size = xm.xrt_world_size()
    ckpt_path = f'{flags.ckpt_prefix}_rank-{rank:08d}-of-{world_size:08d}.pth'
    ckpt = {
        'model': model.state_dict(),
        'shard_metadata': model.get_shard_metadata(),
        'optimizer': optimizer.state_dict(),  # not needed in ckpt consolidation
    }
    # os.path.dirname() is '' for a bare prefix such as 'mnist', and
    # os.makedirs('') raises FileNotFoundError, so only create a real dir.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
      os.makedirs(ckpt_dir, exist_ok=True)
    xm.save(ckpt, ckpt_path, master_only=False)
    print(f'checkpoint saved to {ckpt_path}\n', end='')

    # Consolidate the sharded model checkpoints and test its accuracy
    if xm.is_master_ordinal(local=False):
      consolidate_sharded_model_checkpoints(
          ckpt_prefix=flags.ckpt_prefix, ckpt_suffix="_rank-*-of-*.pth")
    # Wait until the master has written the consolidated file.
    xm.rendezvous('ckpt_consolidation')
    model = MNIST().to(device)
    ckpt_consolidated = torch.load(f'{flags.ckpt_prefix}_consolidated.pth')
    model.load_state_dict(ckpt_consolidated['model'])
    accuracy = test_loop_fn(model, test_device_loader)
    xm.master_print(
        f'Checkpoint consolidated, Accuracy={accuracy:.2f} '
        '(note: it can be slightly different from the final training accuracy '
        'due to non-sync BatchNorm2d in the model)')

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy