def test(self):
  """Smoke-test the per-op timing utilities (xu.TimedScope / xu.timed)
  while training resnet18 on fake data across all XLA devices.

  Batch size and sample count are overridable via the BATCH_SIZE and
  SAMPLE_COUNT environment variables.
  """
  devices = xm.get_xla_supported_devices()
  batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
  sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
  # Fake ImageNet-shaped batches; one `sample_count` worth per device.
  train_loader = xu.SampleGenerator(
      data=(torch.zeros(batch_size, 3, 224, 224),
            torch.zeros(batch_size, dtype=torch.int64)),
      sample_count=sample_count * len(devices))

  def loop_fn(model, loader, device, context):
    # Runs once per device; each step is wrapped in timing scopes with
    # printfn=None so timings are collected but not printed.
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for x, (data, target) in loader:
      with xu.TimedScope(msg='Training loop: ', printfn=None):
        optimizer.zero_grad()
        output = xu.timed(lambda: model(data), msg='Model: ', printfn=None)
        loss = xu.timed(
            lambda: loss_fn(output, target), msg='Loss: ', printfn=None)
        xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
        xu.timed(
            lambda: xm.optimizer_step(optimizer), msg='Step: ', printfn=None)
        # Sanity bound on the loss; `.cpu()` forces a device sync.
        self.assertLess(loss.cpu().item(), 3.0)

  model_parallel = dp.DataParallel(
      torchvision.models.resnet18, device_ids=devices)
  model_parallel(loop_fn, train_loader)
def prepare_task(args, devices):
  """Build the fairseq task, one Trainer per XLA device, and the epoch iterator.

  Returns (task, trainers, model_parallel, epoch_itr, lr, valid_subsets).
  """
  # Setup task, e.g., translation, language modeling, etc.
  task = tasks.setup_task(args)

  # Load valid dataset (we load training data below, based on the latest
  # checkpoint).
  for valid_sub_split in args.valid_subset.split(','):
    task.load_dataset(valid_sub_split, combine=True, epoch=0)

  # Build models and criteria to print some metadata.
  # The lambda lets DataParallel construct one model replica per device.
  model_parallel = dp.DataParallel(
      lambda: task.build_model(args), device_ids=devices)
  # A throwaway model/criterion pair, built only to print metadata below.
  model, criterion = task.build_model(args), task.build_criterion(args)
  print(model)
  print('| model {}, criterion {}'.format(args.arch,
                                          criterion.__class__.__name__))
  print('| num. model params: {} (num. trained: {})'.format(
      sum(p.numel() for p in model.parameters()),
      sum(p.numel() for p in model.parameters() if p.requires_grad),
  ))
  del model, criterion

  # Build trainers: one per device, each wrapping that device's replica.
  trainers = {
      device: Trainer(args, task, model, task.build_criterion(args), xla=True)
      for device, model in zip(model_parallel.devices, model_parallel.models)
  }
  trainer = trainers[devices[0]]
  lr = trainer.get_lr()

  # TODO(taylanbil): for now, this next line is only creating the iterator.
  # validate its behavior with the case where a checkpoint actually exists.
  # Load the latest checkpoint if one is available and restore the
  # corresponding train iterator.
  extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
  valid_subsets = args.valid_subset.split(',')
  return task, trainers, model_parallel, epoch_itr, lr, valid_subsets
np.random.seed(args.seed) torch.manual_seed(args.seed) # load tokenizer tokenizer = AutoTokenizer.from_pretrained(args.bert_model) logging.info(f"Saving tokenizer to: {args.output_dir}") tokenizer.save_pretrained(args.output_dir) # load model model = AutoModelWithLMHead.from_pretrained( args.bert_model) # Only Masked Language Modeling logging.info(f"Saving initial checkpoint to: {args.output_dir}") model.save_pretrained(args.output_dir) # wrap model with TPU stuff model = tpu_dp.DataParallel(model, device_ids=devices) # expected total number of updates total_num_updates = utils.compute_num_updates_in_epoch( num_samples=args.total_num_training_examples, batch_size=args.per_tpu_train_batch_size, grad_accum_steps=args.gradient_accumulation_steps, n_tpu=n_tpu) # expected number of warmup updates if args.warmup_proportion is not None: warmup_updates = int(args.warmup_proportion * total_num_updates) elif args.warmup_steps is not None: warmup_updates = args.warmup_steps else: raise Exception(
def train_cifar():
  """Train a binarized WRN_McDonnell(20, 10) on CIFAR-100 over all XLA devices.

  Returns the final epoch's mean test accuracy as a percentage.
  """
  print('==> Preparing data..')
  # Train-time augmentation: random-pixel pad + crop, flip, Cutout, then
  # normalization with CIFAR-100 channel statistics.
  transform_train = transforms.Compose([
      transforms.Lambda(lambda x: RandomPixelPad(x, padding=4)),
      transforms.RandomCrop(32),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      Cutout(18, random_pixel=True),  # add Cutout
      transforms.Normalize((0.5071, 0.4865, 0.4409),
                           (0.2673, 0.2564, 0.2762)),
  ])
  transform_test = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5071, 0.4865, 0.4409),
                           (0.2673, 0.2564, 0.2762)),
  ])
  trainset = torchvision.datasets.CIFAR100(
      root=FLAGS.datadir, train=True, download=True, transform=transform_train)
  train_loader = torch.utils.data.DataLoader(
      trainset,
      batch_size=FLAGS.batch_size,
      shuffle=True,
      num_workers=FLAGS.num_workers)
  testset = torchvision.datasets.CIFAR100(
      root=FLAGS.datadir, train=False, download=True, transform=transform_test)
  test_loader = torch.utils.data.DataLoader(
      testset,
      batch_size=FLAGS.batch_size,
      shuffle=False,
      num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Define model here
  model = WRN_McDonnell(20, 10, 100, binarize=True)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Runs once per device per epoch; optimizer and LR scheduler are cached
    # in `context` so state (momentum, schedule position) survives epochs.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    # LR scheduler
    scheduler = context.getattr_or(
        'scheduler',
        lambda: CosineAnnealingRestartsLR(optimizer, T=2, eta_min=1e-4))
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f}'.format(device, x, loss.item()))
      # Step LR scheduler (per batch, matching the scheduler's T setting).
      scheduler.step()

  def test_loop_fn(model, loader, device, context):
    # Per-device evaluation; returns top-1 accuracy in [0, 1].
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    return correct / total_samples

  best_accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    # FIX: divide by len(accuracies), not len(devices). With num_cores == 0
    # `devices` is [] (the documented PyTorch/CPU-engine path) and dividing
    # by len(devices) raised ZeroDivisionError even though one accuracy is
    # still returned. Also consistent with the other train_* variants.
    accuracy = sum(accuracies) / len(accuracies)
    print('Epoch {}, Accuracy={:.2f}%'.format(epoch, 100.0 * accuracy))
    # Keep track of best model
    if accuracy > best_accuracy:
      best_accuracy = accuracy
      torch.save(model_parallel._models[0].state_dict(), 'model.pt')
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy * 100.0
def train_mnist():
  """Train the MNIST model across all XLA devices; returns mean accuracy (%)."""
  torch.manual_seed(1)
  if FLAGS.fake_data:
    # Synthetic MNIST-shaped batches sized to mimic the real split sizes.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=60000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size)
  else:
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            FLAGS.datadir,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    # NOTE(review): the test loader also shuffles (shuffle=True). Accuracy is
    # unaffected, but shuffle=False would be conventional — confirm intent.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            FLAGS.datadir,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Scale learning rate to num cores
  lr = FLAGS.lr * max(len(devices), 1)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(MNIST, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Per-device training loop; the optimizer is cached in `context` so its
    # momentum state persists across epochs.
    loss_fn = nn.NLLLoss()
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(), lr=lr, momentum=FLAGS.momentum))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
                                                        tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    # Per-device evaluation; returns top-1 accuracy in [0, 1].
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy * 100.0
def train_imagenet():
  """Train a torchvision model on ImageNet across all XLA devices.

  Returns the final epoch's mean test accuracy as a percentage.
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=1200000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size)
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    # FIX: the validation set previously reused the train-time random
    # transforms (RandomResizedCrop + RandomHorizontalFlip) and shuffle=True,
    # making eval accuracy non-deterministic. Use the standard deterministic
    # resize + center-crop eval pipeline instead.
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        transforms.Compose([
            transforms.Resize(max(img_dim, 256)),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Per-device training loop.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=5e-4)
    tracker = xm.RateTracker()
    # FIX: explicitly switch to train mode (batchnorm/dropout); previously
    # whatever mode the model was left in carried over from eval.
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
                                                        tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    # Per-device evaluation; returns top-1 accuracy in [0, 1].
    total_samples = 0
    correct = 0
    # FIX: evaluate in eval mode so batchnorm uses running statistics.
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    # FIX: average over the results actually returned (robust to the CPU
    # engine path where devices == [] but one result is still produced).
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy * 100.0
def train_cifar():
  """Train ResNet18 on CIFAR-10 with optional multi-replica (XRT) sharding.

  Returns the final epoch's mean test accuracy as a percentage.
  """
  print('==> Preparing data..')
  if FLAGS.fake_data:
    # Fake data: shard the nominal dataset size across the XRT world so the
    # per-replica epoch length matches the distributed real-data case.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    train_sampler = None
    test_sampler = None
    # In multi-replica runs, give each ordinal a disjoint shard.
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    # `sampler` and `shuffle` are mutually exclusive in DataLoader, hence the
    # conditional shuffle flag.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
  model_parallel = dp.DataParallel(model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Per-device training loop; optimizer cached in `context` so momentum
    # state survives across epochs.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    # Per-device evaluation; returns top-1 accuracy as a percentage.
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy
def main():
  """Masked-LM pretraining of a BERT-style model on TPU via data parallelism.

  Each epoch loads one pregenerated training shard, runs `tpu_training_loop`
  on every TPU core, and checkpoints the first replica's weights.
  """
  parser = utils.get_args_parser_with_general_args()
  parser.add_argument(
      '--one_tpu',
      action='store_true',
      help=
      "Run on one tpu core for degugging. Makes it easy to use break points")
  parser.add_argument(
      '--tpu_report', action='store_true', help="Print xla metric report")
  args = parser.parse_args()

  utils.init(args)  # set seeds, init logger, prepare output directory
  devices = tpu_xm.get_xla_supported_devices()
  if args.one_tpu:
    devices = [devices[0]]
  n_tpu = len(devices)
  logging.info(f'Found {n_tpu} TPU cores')

  tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
  tokenizer.save_pretrained(args.output_dir)

  args.start_epoch = utils.prepare_last_checkpoint(args.bert_model)
  model = AutoModelWithLMHead.from_pretrained(
      args.bert_model)  # Only Masked Language Modeling
  logging.info(f"Saving initial checkpoint to: {args.output_dir}")
  model.save_pretrained(args.output_dir)
  # One model replica per TPU core.
  model = tpu_dp.DataParallel(model, device_ids=devices)

  num_data_epochs, num_train_optimization_steps = utils.get_dataset_stats(
      args, n_tpu)

  def tpu_training_loop(model, loader, device, context):
    """ Called by torch_xla_py.data_parallel. This function is
    executed on each core of the TPU once per epoch"""
    param_optimizer = list(model.named_parameters())
    # Exclude biases and LayerNorm parameters from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in param_optimizer
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    # one optimizer and scheduler per TPU core. Both objects are saved in
    # `context` to be reused the next epoch.
    # NOTE(review): other loops in this codebase pass a lambda factory to
    # getattr_or; here fully-constructed AdamW / WarmupLinearSchedule objects
    # are passed instead — confirm getattr_or accepts non-callable defaults,
    # otherwise these are rebuilt (and AdamW state lost) every epoch.
    optimizer = context.getattr_or(
        'optimizer',
        AdamW(optimizer_grouped_parameters,
              lr=args.learning_rate,
              eps=args.adam_epsilon,
              betas=tuple(args.betas)))

    # derive warmup info
    if args.warmup_proportion is not None:
      warmup_steps = int(args.warmup_proportion *
                         num_train_optimization_steps + 0.5)
    elif args.warmup_steps is not None:
      warmup_steps = args.warmup_steps
    else:
      raise Exception(
          'What is the warmup?? Specify either warmup proportion or steps')
    scheduler = context.getattr_or(
        'scheduler',
        WarmupLinearSchedule(optimizer,
                             warmup_steps=warmup_steps,
                             t_total=num_train_optimization_steps))

    tr_loss = None
    pbar = None
    if str(pbar_device) == str(
        device
    ):  # All threads are in sync. Use progress bar only on one of them
      pbar = tqdm(total=int(pbar_steps),
                  desc=f"device {device}",
                  dynamic_ncols=True)

    tracker = tpu_xm.RateTracker()
    model.train()
    for step, batch in loader:
      input_ids, input_mask, segment_ids, lm_label_ids, _ = batch
      outputs = model(input_ids, segment_ids, input_mask, lm_label_ids)
      loss = outputs[0]
      if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps
      loss.backward()
      tracker.add(args.train_batch_size)
      # Re-scale back to the un-divided loss when accumulating the total.
      tr_loss = loss * args.gradient_accumulation_steps if step == 0 else tr_loss + loss * args.gradient_accumulation_steps
      if pbar is not None:
        pbar.update(1)
        # pbar.set_description(desc=f'LR: {scheduler.get_lr()}')
      # Apply the optimizer only on accumulation boundaries.
      if (step + 1) % args.gradient_accumulation_steps == 0:
        tpu_xm.optimizer_step(optimizer)
        prev_lr = scheduler.get_last_lr()[0]
        scheduler.step()
        curr_lr = scheduler.get_last_lr()[0]
        if args.track_learning_rate:
          if pbar is not None:
            pbar.set_description(f"Prev LR: {prev_lr} Curr LR: {curr_lr}")
        optimizer.zero_grad()
    # NOTE(review): divides by the last step index, not the step count
    # (off by one), and fails if the loader is empty — confirm intended.
    return tr_loss.item(
    ) / step  # `.item()` requires a trip from TPU to CPU, which is very slow. Use it only once per epoch

  for epoch in range(args.start_epoch, args.epochs):
    # Load one training file into memory
    epoch_dataset = utils.PregeneratedDataset(
        epoch=epoch,
        training_path=args.pregenerated_data,
        tokenizer=tokenizer,
        num_data_epochs=num_data_epochs,
        reduce_memory=args.reduce_memory)
    train_sampler = RandomSampler(epoch_dataset)
    train_dataloader = DataLoader(epoch_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    pbar_device = devices[0]
    pbar_steps = utils.compute_num_steps_in_epoch(
        num_samples=train_sampler.num_samples,
        batch_size=args.train_batch_size,
        grad_accum_steps=
        1,  # the pbar steps should not take into account grad accumulation steps
        n_tpu=n_tpu)
    logging.info(
        f'start training, epoch {epoch} on {len(devices)} cores for {pbar_steps} steps'
    )
    start = time.time()
    losses = model(
        tpu_training_loop, train_dataloader
    )  # calls `tpu_training_loop` multiple times, once per TPU core
    logging.info(
        f'Epoch {epoch} took {round(time.time() - start, 2)} seconds. Average loss: {sum(losses)/len(losses)}'
    )
    utils.save_checkpoint(model._models[0], epoch, args.output_dir)
  if args.tpu_report:
    logging.info(torch_xla._XLAC._xla_metrics_report())
def train_cifar():
  """Train ResNet18 on CIFAR-10 across all XLA devices.

  Returns the final epoch's mean test accuracy as a percentage.
  """
  print('==> Preparing data..')
  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size)
  else:
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    testset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(ResNet18, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Per-device training loop.
    # NOTE(review): a fresh SGD is built on every epoch invocation, so
    # momentum state resets each epoch; the newer variant caches it in
    # `context`. Kept as-is to preserve behavior.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=5e-4)
    tracker = xm.RateTracker()
    # FIX: ResNet18 contains batchnorm; explicitly enter train mode instead
    # of inheriting whatever mode the previous loop left the model in.
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
            device, x, loss.item(), tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    # Per-device evaluation; returns top-1 accuracy in [0, 1].
    total_samples = 0
    correct = 0
    # FIX: evaluate in eval mode so batchnorm uses running statistics and
    # the running stats are not polluted by test batches.
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    # FIX: average over the results actually returned (len(accuracies))
    # rather than len(devices), consistent with the other train_* variants
    # and safe if the CPU-engine path returns a single result.
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy * 100.0
def train_imagenet():
  """Train a torchvision model on ImageNet with multi-replica (XRT) sharding.

  Returns the final epoch's mean test accuracy as a percentage and logs it
  to TensorBoard when FLAGS.logdir is set.
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=1200000 // FLAGS.batch_size // xm.xrt_world_size())
    # NOTE(review): sample_count divides by FLAGS.batch_size although the
    # batches are sized test_set_batch_size — confirm whether intended.
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    # Eval resize target: at least 256, or the model's native input size.
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = None
    test_sampler = None
    # In multi-replica runs, give each ordinal a disjoint shard.
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    # `sampler` and `shuffle` are mutually exclusive in DataLoader.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Per-device training loop; optimizer cached in `context` so momentum
    # state survives across epochs.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    # Per-device evaluation; returns top-1 accuracy as a percentage.
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy
upsample_primary=16, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=4000, weight_decay=0.0) task = tasks.setup_task(args) task.load_dataset(args.train_subset, combine=True, epoch=0) for valid_sub_split in args.valid_subset.split(','): task.load_dataset(valid_sub_split, combine=True, epoch=0) #devices = xm.get_xla_supported_devices(max_devices=8) # Got error for max devices argument :( devices = xm.get_xla_supported_devices() model_parallel = dp.DataParallel(lambda: task.build_model(args), device_ids=devices) #max_positions = utils.resolve_max_positions( # task.max_positions(), # model.max_positions(), # # ) max_positions = (1024, 1024 ) # Hardcoded for the moment since the computation requires # model object which will be created by model_parallel __call__ # Re-factor in a cleaner way # Initialize dataloader epoch_itr = task.get_batch_iterator( dataset=task.dataset(args.train_subset), max_tokens=args.max_tokens,
def main():
  """Train/evaluate/predict a BERT sequence classifier on TPU cores.

  Modes (mutually exclusive at the top level): --do_train runs the
  train+eval loop; --do_predict writes a prediction CSV (optionally scored
  against labels when --do_eval is also set).
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--train_file",
                      default=None,
                      type=str,
                      required=True,
                      help="The train file path")
  parser.add_argument("--eval_file",
                      default=None,
                      type=str,
                      required=True,
                      help="The dev file path")
  parser.add_argument("--predict_file",
                      default=None,
                      type=str,
                      required=False,
                      help="The predict file path")
  parser.add_argument("--predict_result_file",
                      default=None,
                      type=str,
                      required=False,
                      help="The predict result file path")
  parser.add_argument(
      "--bert_model",
      default=None,
      type=str,
      required=True,
      help=
      "The config json file corresponding to the pre-trained BERT model. \n"
      "This specifies the model architecture.")
  parser.add_argument(
      "--output_dir",
      default=None,
      type=str,
      required=True,
      help="The output directory where the model checkpoints will be written."
  )
  parser.add_argument(
      "--init_checkpoint",
      default=None,
      type=str,
      help="Initial checkpoint (usually from a pre-trained BERT model).")
  parser.add_argument(
      "--do_lower_case",
      default=False,
      action='store_true',
      help=
      "Whether to lower case the input text. True for uncased models, False for cased models."
  )
  parser.add_argument(
      "--max_seq_length",
      default=300,
      type=int,
      help=
      "The maximum total input sequence length after WordPiece tokenization. \n"
      "Sequences longer than this will be truncated, and sequences shorter \n"
      "than this will be padded.")
  parser.add_argument("--do_train",
                      default=False,
                      action='store_true',
                      help="Whether to run training.")
  parser.add_argument("--do_predict",
                      default=False,
                      action='store_true',
                      help="Whether to run eval on the dev set.")
  parser.add_argument("--do_eval",
                      default=False,
                      action='store_true',
                      help="Whether to run training.")
  parser.add_argument("--num_labels",
                      default=1,
                      type=int,
                      help="mapping classify nums")
  parser.add_argument("--train_batch_size",
                      default=32,
                      type=int,
                      help="Total batch size for training.")
  parser.add_argument("--eval_batch_size",
                      default=8,
                      type=int,
                      help="Total batch size for eval.")
  parser.add_argument("--learning_rate",
                      default=5e-5,
                      type=float,
                      help="The initial learning rate for Adam.")
  parser.add_argument("--num_train_epochs",
                      default=6.0,
                      type=float,
                      help="Total number of training epochs to perform.")
  parser.add_argument(
      "--warmup_proportion",
      default=0.1,
      type=float,
      help=
      "Proportion of training to perform linear learning rate warmup for. "
      "E.g., 0.1 = 10%% of training.")
  parser.add_argument("--local_rank",
                      type=int,
                      default=-1,
                      help="local_rank for distributed training on gpus")
  parser.add_argument('--seed',
                      type=int,
                      default=42,
                      help="random seed for initialization")
  parser.add_argument(
      '--gradient_accumulation_steps',
      type=int,
      default=1,
      help=
      "Number of updates steps to accumualte before performing a backward/update pass."
  )
  args = parser.parse_args()

  vocab_path = os.path.join(args.bert_model, VOCAB_NAME)
  # bert_config = BertConfig.from_json_file(vocab_path)
  data_processor = DataProcessor()
  devices = tpu_xm.get_xla_supported_devices()
  n_tpu = len(devices)
  logging.info(f'Found {n_tpu} TPU cores')

  # Scale the per-step batch down so accumulated steps match the nominal
  # total batch size.
  args.train_batch_size = int(args.train_batch_size /
                              args.gradient_accumulation_steps)
  random.seed(args.seed)
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)

  if args.do_train:
    # Refuse to overwrite a non-empty output directory.
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
      raise ValueError(
          "Output directory ({}) already exists and is not empty.".format(
              args.output_dir))
    else:
      os.makedirs(args.output_dir, exist_ok=True)

  tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path,
                                         do_lower_case=args.do_lower_case)
  # NOTE(review): num_labels is hard-coded to 3 here, ignoring --num_labels
  # (default 1) — confirm which is intended.
  model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                        num_labels=3)
  for k, v in model.state_dict().items():
    print(f'k = {k}, v.grad = {v.grad}')
  model = tpu_dp.DataParallel(model, device_ids=devices)

  if args.do_train:
    # Read datasets.
    train_examples = data_processor.get_examples(args.train_file,
                                                 data_type='train')
    eval_examples = data_processor.get_examples(args.eval_file,
                                                data_type='eval')
    # Convert examples to features.
    train_features = convert_examples_to_features(args, train_examples,
                                                  args.max_seq_length,
                                                  tokenizer)
    eval_features = convert_examples_to_features(args, eval_examples,
                                                 args.max_seq_length,
                                                 tokenizer)
    num_train_steps = int(
        len(train_features) // args.train_batch_size //
        args.gradient_accumulation_steps * args.num_train_epochs)
    # Wrap features for data-parallel loading.
    train_loader = ParaDataloader(train_features)
    eval_loader = ParaDataloader(eval_features)
    # DataLoader format expected by the data-parallel wrapper.
    train_loader = DataLoader(train_loader,
                              shuffle=True,
                              batch_size=args.train_batch_size)
    eval_loader = DataLoader(eval_loader,
                             shuffle=False,
                             batch_size=args.eval_batch_size)

  def tpu_training_loop(model, loader, device, context):
    """ Called by torch_xla_py.data_parallel. This function is executed
    on each core of the TPU once per epoch"""
    model.zero_grad()
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    param_optimizer = list(model.named_parameters())
    # NOTE(review): `n not in no_decay` tests full-name list membership, not
    # substring match (`any(nd in n ...)` as in sibling scripts) — as written
    # every parameter lands in the weight-decay group. Confirm intent.
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if n not in no_decay],
        'weight_decay_rate': 0.01
    }, {
        'params': [p for n, p in param_optimizer if n in no_decay],
        'weight_decay_rate': 0.0
    }]
    # Optimizer is cached in `context` and reused on later epochs.
    optimizer = context.getattr_or(
        'optimizer',
        BertAdam(optimizer_grouped_parameters,
                 lr=args.learning_rate,
                 warmup=args.warmup_proportion,
                 t_total=num_train_steps))
    tr_loss = None
    pbar = None
    # Progress bar only on one device; all cores run in lockstep.
    if str(pbar_device) == str(device):
      pbar = tqdm(total=int(pbar_steps),
                  desc=f"training",
                  dynamic_ncols=True)
    tracker = tpu_xm.RateTracker()
    model.train()
    for step, batch in enumerate(loader):
      input_ids, input_mask, segment_ids, label_ids = batch
      loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
      if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps
      loss.backward()
      tracker.add(args.train_batch_size)
      tr_loss = loss * args.gradient_accumulation_steps if step == 0 else tr_loss + loss * args.gradient_accumulation_steps
      if pbar is not None:
        pbar.update(1)
      # NOTE(review): the optimizer steps on EVERY batch, not on
      # gradient-accumulation boundaries — confirm intent.
      tpu_xm.optimizer_step(optimizer)  # optimizer.step()
      optimizer.zero_grad()
    # NOTE(review): divides by the last step index (off by one) and fails
    # for an empty loader.
    return tr_loss.item() / step

  def tpu_evaluating_loop(model, eval_dataloader, device, context):
    # Per-core evaluation; returns (mean-ish loss, predicted labels, labels).
    model.eval()
    eval_loss = 0
    eval_pbar = None
    logits, labels = [], []
    if str(pbar_device) == str(device):
      eval_pbar = tqdm(total=int(eval_pbar_steps),
                       desc=f"evaluating",
                       dynamic_ncols=True)
    tracker = tpu_xm.RateTracker()
    for step, batch in enumerate(eval_dataloader):
      input_ids, input_mask, segment_ids, label_ids = batch
      with torch.no_grad():
        loss, logit = model(input_ids, segment_ids, input_mask, label_ids)
        eval_loss = loss * args.gradient_accumulation_steps if step == 0 else eval_loss + loss * args.gradient_accumulation_steps
        logit = torch.argmax(logit, dim=-1)
        logits.extend(logit.tolist())
        labels.extend(label_ids.tolist())
        tracker.add(args.eval_batch_size)
      if eval_pbar is not None:
        eval_pbar.update(1)
    return (eval_loss.item() / step, logits, labels)

  def tpu_predicting_loop(model, dataloader, device, context):
    # Per-core prediction; returns (argmax labels, example ids, softmax probs).
    model.eval()
    eval_pbar = None
    logits, example_ids, probs = [], [], []
    if str(pbar_device) == str(device):
      eval_pbar = tqdm(total=int(eval_pbar_steps),
                       desc=f"evaluating",
                       dynamic_ncols=True)
    tracker = tpu_xm.RateTracker()
    for step, batch in enumerate(dataloader):
      # For prediction, `label_ids` actually carries the example ids.
      input_ids, input_mask, segment_ids, label_ids = batch
      with torch.no_grad():
        logit = model(input_ids, segment_ids, input_mask)
        prob = torch.softmax(logit, dim=-1).tolist()
        logit = torch.argmax(logit, dim=-1)
        logits.extend(logit.tolist())
        example_ids.extend(label_ids.tolist())
        probs.extend(prob)
        tracker.add(args.eval_batch_size)
      if eval_pbar is not None:
        eval_pbar.update(1)
    return logits, example_ids, probs

  def eval_meric(model, loop, data_loader):
    # Aggregate per-core eval results and log overall accuracy/loss.
    eval_results = model(loop, data_loader)
    # NOTE(review): duplicate target `eval_loss, eval_loss = 0, 0` — the
    # second name was probably meant to be something else. Harmless as-is.
    eval_loss, eval_loss = 0, 0
    all_logits, all_labels = [], []
    # NOTE(review): hard-codes an 8-core TPU; fails with --one_tpu or other
    # topologies.
    assert len(eval_results) == len(devices) == 8
    for eval_result in eval_results:
      eval_loss += eval_result[0]
      all_logits.extend(eval_result[1])
      all_labels.extend(eval_result[2])
    accuracy(all_labels, all_logits)
    logger.info(f'Average eval loss = {eval_loss / len(eval_results)}')

  def write_predict_file(model, loop, data_loader, file_path):
    """ Write the prediction file; format: '五彩滨云-final.csv' """
    results = model(loop, data_loader)
    logits, ids, probs = [], [], []
    # NOTE(review): hard-codes an 8-core TPU (see eval_meric).
    assert len(results) == len(devices) == 8
    for result in results:
      logits.extend(result[0])
      ids.extend(result[1])
      probs.extend(result[2])
    assert len(ids) == len(logits)
    logger.info(
        f'zero nums {logits.count(0)}, one nums {logits.count(1)}, two nums {logits.count(2)}'
    )
    labels = [
        data_processor.eval_dict[id][1] for id, logit in zip(ids, logits)
    ]
    if not args.do_eval:
      # Shift predictions from {0,1,2} back to the original {-1,0,1} labels.
      logits = [i - 1 for i in logits]
      data_df = pd.DataFrame({'id': ids, 'y': logits})
      data_df1 = pd.DataFrame({'id': ids, 'y': logits, 'probs': probs})
      data_df1.to_csv('probs_predict.csv', index=None)
    else:
      assert len(labels) == len(logits)
      accuracy(labels, logits)
      passages = [
          data_processor.eval_dict[id][0] for id, logit in zip(ids, logits)
      ]
      assert len(labels) == len(passages)
      match_array = np.array((logits)) == np.array(labels)
      match_list = match_array.tolist()
      data_df = pd.DataFrame({
          'id': ids,
          'pred': logits,
          'real': labels,
          'probs': probs,
          'match': match_list,
          'passage': passages
      })
    data_df.to_csv(file_path, index=None)

  if args.do_train:
    for epoch in range(1, int(args.num_train_epochs) + 1, 1):
      pbar_device = devices[0]
      logger.info(f'Start to evaluate......')
      eval_pbar_steps = len(eval_loader) // n_tpu
      eval_meric(model, tpu_evaluating_loop, eval_loader)
      pbar_steps = len(train_loader) // n_tpu
      logging.info(
          f'Start training, epoch {epoch} on {len(devices)} cores for {pbar_steps} steps'
      )
      start = time.time()
      losses = model(tpu_training_loop, train_loader)
      logging.info(
          f'Epoch {epoch} took {round(time.time() - start, 2)} seconds. average train loss: {sum(losses) / len(losses)}'
      )
      save_checkpoint(model._models[0], epoch, args.output_dir)
    logger.info('Train finished......')
  elif args.do_predict:
    pbar_device = devices[0]
    logger.info(f'Start to predict......')
    if args.do_eval:
      predict_examples = data_processor.get_eval_examples(args.eval_file)
    else:
      predict_examples = data_processor.get_predict_examples(
          args.predict_file)
    predict_features = convert_examples_to_features(args, predict_examples,
                                                    args.max_seq_length,
                                                    tokenizer)
    predict_loader = ParaDataloader(predict_features)
    predict_loader = DataLoader(predict_loader,
                                shuffle=False,
                                batch_size=args.eval_batch_size)
    eval_pbar_steps = len(predict_loader) // n_tpu
    write_predict_file(model, tpu_predicting_loop, predict_loader,
                       args.predict_result_file)
def train_mnist():
    """Meta-train a relation network on the China-drinks dataset on TPU.

    Builds train/test episode loaders, replicates ``CNN_Plus_RNEncoder``
    across the available TPU cores with ``dp.DataParallel``, runs
    ``FLAGS.num_epochs`` rounds of train+test, and returns the final test
    accuracy as a percentage (float).
    """
    torch.manual_seed(1)

    # Step 1: init data folders
    print("init data folders", flush=True)
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tgtpu.china_drinks_sku_folders(
        DATASET_FOLDER, SAMPLE_NUM_PER_CLASS, QUERY_NUM_PER_CLASS,
        VALIDATION_SPLIT_PERCENTAGE)

    devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
    # Scale learning rate to num cores (NOTE: currently unused -- the
    # per-core optimizer below uses the LEARNING_RATE constant instead).
    lr = FLAGS.lr * len(devices)

    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(CNN_Plus_RNEncoder, device_ids=devices)

    # One random rotation augmentation, shared by the train and test tasks.
    degrees = random.choice([0, 90, 180, 270])
    train_task = tgtpu.ChinaDrinksTask(metatrain_character_folders,
                                       CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                       QUERY_NUM_PER_CLASS)
    train_sample_batch_dataloader = tgtpu.get_data_loader(
        train_task,
        image_size=IMAGE_SIZE,
        sample_num_per_class=SAMPLE_NUM_PER_CLASS,
        query_num_per_class=QUERY_NUM_PER_CLASS,
        train_shuffle=False,
        query_shuffle=True,
        rotation=degrees,
        num_workers=NO_OF_TPU_CORES)
    test_task = tgtpu.ChinaDrinksTask(metatest_character_folders, CLASS_NUM,
                                      SAMPLE_NUM_PER_CLASS,
                                      SAMPLE_NUM_PER_CLASS)
    test_sample_test_dataloader = tgtpu.get_data_loader(
        test_task,
        IMAGE_SIZE,
        sample_num_per_class=SAMPLE_NUM_PER_CLASS,
        query_num_per_class=QUERY_NUM_PER_CLASS,
        train_shuffle=False,
        query_shuffle=True,
        rotation=degrees,
        num_workers=NO_OF_TPU_CORES)

    def train_loop_fn(model, loader, device, context):
        """Per-core training loop, invoked by dp.DataParallel."""
        relation_network = model
        #relation_network.apply(weights_init)
        relation_network_optim = torch.optim.Adam(
            relation_network.parameters(), lr=LEARNING_RATE)
        # Halve the LR every 100k episodes.
        relation_network_scheduler = StepLR(relation_network_optim,
                                            step_size=100000,
                                            gamma=0.5)
        mse = nn.MSELoss()
        tracker = xm.RateTracker()
        for x, (samples, sample_labels, batches, batch_labels) in loader:
            # BUG FIX: original called relation_network_scheduler.step(episode)
            # with `episode` undefined (NameError on the first batch). Use the
            # loader's step counter `x` as the episode index instead.
            relation_network_scheduler.step(x)
            relation_network.zero_grad()
            #relation_network_optim.zero_grad()
            # Relation scores between support samples and query batches.
            relation_scores = relation_network(Variable(samples),
                                               Variable(batches))
            relations = relation_scores.view(-1, CLASS_NUM)
            # One-hot encode the query labels for the MSE objective.
            one_hot_labels = Variable(
                torch.zeros(QUERY_NUM_PER_CLASS * CLASS_NUM,
                            CLASS_NUM).scatter_(1, batch_labels.view(-1, 1),
                                                1))
            loss = mse(relations, one_hot_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
            xm.optimizer_step(relation_network_optim)
            tracker.add(FLAGS.batch_size)
            print('Debug: ', x, loss.item())
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, x, loss.item(), tracker.rate()))

    def test_loop_fn(model, loader, device, context):
        """Per-core test loop; returns accuracy averaged over TEST_EPISODE."""
        relation_network = model
        total_rewards = 0
        for x, (samples, sample_labels, batches, batch_labels) in loader:
            relation_scores = relation_network(Variable(samples),
                                               Variable(batches))
            relations = relation_scores.view(-1, CLASS_NUM)
            _, predict_labels = torch.max(relations.data, 1)
            # BUG FIX: original compared against undefined `test_labels`
            # (NameError); the ground-truth query labels are `batch_labels`.
            rewards = [
                1 if predict_labels[j] == batch_labels[j] else 0
                for j in range(CLASS_NUM * SAMPLE_NUM_PER_CLASS)
            ]
            total_rewards += np.sum(rewards)
        # Normalize by episodes x queries per episode.
        test_accuracy = total_rewards / 1.0 / CLASS_NUM / SAMPLE_NUM_PER_CLASS / TEST_EPISODE
        print('[{}] Accuracy={:.2f}%'.format(device, 100 * test_accuracy))
        return test_accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_sample_batch_dataloader)
        # One accuracy value per core; average them.
        accuracies = model_parallel(test_loop_fn, test_sample_test_dataloader)
        accuracy = sum(accuracies) / len(devices)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())
    return accuracy * 100.0
    # Strip the 'module.' prefix (7 chars) left by DataParallel checkpoints,
    # keeping only keys that exist in the current model.
    k[7:]: v
    for k, v in pretrained_dict.items()
    if k[7:] in model_dict
}
# Merge the filtered pretrained weights into the model's state dict.
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print('Set cache dir', flush=True)
# NOTE(review): this rebinds the name `time` to a datetime instance,
# shadowing any imported `time` module from here on -- confirm intended.
time = datetime.datetime.now()
num_cores = 8
# Empty device list (num_cores == 0) would fall back to the CPU engine.
devices = (xm.get_xla_supported_devices(
    max_devices=num_cores) if num_cores != 0 else [])
# Scale learning rate to num cores
base_lr = args.base_lr * max(len(devices), 1)
# Pass [] as device_ids to run using the PyTorch/CPU engine.
model_parallel = dp.DataParallel(model, device_ids=devices)
# optimizer prepare: collect ids of the freshly added head layers so they
# can be given their own learning rate, separate from the backbone.
ignored_params1 = list(map(id, model.classifier.parameters()))
ignored_params2 = list(map(id, model.classifier_swap.parameters()))
ignored_params3 = list(map(id, model.Convmask.parameters()))
ignored_params = ignored_params1 + ignored_params2 + ignored_params3
print('the num of new layers:', len(ignored_params), flush=True)
# Backbone parameters = everything not in the new-layer id set.
# NOTE(review): `filter` returns a single-use iterator; it must be consumed
# exactly once by the optimizer construction that follows.
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=0.1)
# exp_lr_scheduler.step(epoch)
# train entry