def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    """Replicate `network` across the given XLA devices.

    Args:
        network: an `nn.Module` instance, or a zero-argument factory that
            builds one.
        device_ids: XLA device strings; defaults to every supported device.
        batchdim: dimension along which inputs are split across replicas.
        drop_last: whether a trailing partial batch is dropped.
    """
    if device_ids is None:
        device_ids = xm.get_xla_supported_devices()
    self._device_ids = list(device_ids)
    self._batchdim = batchdim
    self._drop_last = drop_last
    self._native_run = False
    # Replication only makes sense when more than one device is in play.
    self._replication = None
    if len(self._device_ids) > 1:
        self._replication = xm.Replication(
            self._device_ids, xm.xla_replication_devices(self._device_ids))
    self._models = []
    self._contexts = []
    # Accept either a ready-made module or a factory callable.
    base_model = network if isinstance(network, torch.nn.Module) else network()
    for device in device_ids:
        xla_device = torch.device(device)
        self._models.append(deepcopy(base_model).to(device=xla_device))
        self._contexts.append(Context(xla_device))
    if not self._models:
        # No XLA device, push a vanilla network in.
        native_device = self._get_model_device(base_model)
        self._models.append(base_model)
        self._device_ids.append(native_device)
        self._contexts.append(Context(torch.device(native_device)))
        self._native_run = True
def run_benchmark(args, pos_args):
    """Benchmark host-to-device tensor transfers, one thread per XLA device."""
    devices = xm.get_xla_supported_devices(max_devices=args.max_devices)
    shape = [int(x) for x in args.shape.split(',')]
    # Pre-generate `prefetch` random tensors per device so the timed section
    # measures only the transfer, not tensor creation.
    send_list = [[torch.randn(*shape) for _ in range(args.prefetch)]
                 for _ in range(len(devices))]

    def threadfn(i):
        xdevices = [devices[i]] * len(send_list[i])
        for n in range(args.test_count):
            with xu.TimedScope(
                    msg='Send[{}][{}]: '.format(i, n), printfn=print):
                _ = torch_xla._XLAC._xla_tensors_from_aten(
                    send_list[i], xdevices)

    threads = [
        threading.Thread(target=threadfn, args=(i,))
        for i in range(len(devices))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(torch_xla._XLAC._xla_metrics_report())
def test(self):
    """Time a short ResNet-18 training loop on every XLA device and check the loss stays bounded."""
    devices = xm.get_xla_supported_devices()
    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
    sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
    # Synthetic ImageNet-shaped all-zero batches; `sample_count` per device.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 3, 224, 224),
              torch.zeros(batch_size, dtype=torch.int64)),
        sample_count=sample_count * len(devices))

    def loop_fn(model, loader, device, context):
        # Per-device training loop; each stage is timed individually
        # (printfn=None so timings are recorded without console spam).
        loss_fn = nn.NLLLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        for x, (data, target) in loader:
            with xu.TimedScope(msg='Training loop: ', printfn=None):
                optimizer.zero_grad()
                output = xu.timed(
                    lambda: model(data), msg='Model: ', printfn=None)
                loss = xu.timed(
                    lambda: loss_fn(output, target),
                    msg='Loss: ',
                    printfn=None)
                xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
                xu.timed(
                    lambda: xm.optimizer_step(optimizer),
                    msg='Step: ',
                    printfn=None)
                # With all-zero inputs the loss should stay well under 3.0.
                self.assertLess(loss.cpu().item(), 3.0)

    model_parallel = dp.DataParallel(
        torchvision.models.resnet18, device_ids=devices)
    model_parallel(loop_fn, train_loader)
def test(self):
    """ParallelLoader must be able to place each batch onto every device."""
    devices = xm.get_xla_supported_devices()
    slope = 3.11
    offset = 4.09
    batch_size = 128 * len(devices)
    # Synthetic linear data: y = slope * x + offset.
    gen = xu.FnDataGenerator(
        lambda x: x * slope + offset,
        batch_size,
        _gen_tensor,
        dims=[8],
        count=10)
    para_loader = pl.ParallelLoader(gen, batch_size, devices)
    for step, (data, target) in para_loader:
        for device in devices:
            moved = para_loader.to(data, device)
            self.assertEqual(moved.device, torch.device(device))
def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    """Build one fresh instance of `network` per XLA device.

    Args:
        network: zero-argument factory producing an `nn.Module`.
        device_ids: XLA device strings; defaults to every supported device.
        batchdim: dimension along which inputs are split across replicas.
        drop_last: whether a trailing partial batch is dropped.
    """
    if device_ids is None:
        device_ids = xm.get_xla_supported_devices()
    self._batchdim = batchdim
    self._drop_last = drop_last
    self._device_ids = list(device_ids)
    # Replication is only meaningful when at least one device is present.
    self._replication = None
    if self._device_ids:
        self._replication = xm.Replication(self._device_ids)
    self._models = [
        network().to(device=torch.device(device)) for device in device_ids
    ]
    if not self._models:
        # No XLA device, push a vanilla network in.
        self._models.append(network())
def test(self):
    """Tensors can move CPU->device, device->device, and still compute consistently."""
    devices = xm.get_xla_supported_devices()
    # CPU -> each device, walked in reverse order.
    for device in reversed(devices):
        source = _gen_tensor(8, 12)
        moved = source.to(device=torch.device(device))
        self.assertEqual(moved.device, torch.device(device))
    # First device -> every other device.
    source = _gen_tensor(8, 12).to(device=torch.device(devices[0]))
    for device in devices[1:]:
        moved = source.to(device=torch.device(device))
        self.assertEqual(moved.device, torch.device(device))
    # For each adjacent device pair, compute after the copy and compare on CPU.
    for dev0, dev1 in zip(devices, devices[1:]):
        t0 = torch.zeros(4, 4, device=torch.device(dev0))
        t1 = t0.to(device=torch.device(dev1))
        t0 = t0 + torch.ones_like(t0, device=torch.device(dev0))
        t1 = t1 + torch.ones_like(t1, device=torch.device(dev1))
        self.assertEqual(t0.cpu(), t1.cpu())
def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    """Deep-copy `network` onto each XLA device.

    Args:
        network: an `nn.Module` instance, or a zero-argument factory that
            builds one.
        device_ids: XLA device strings; defaults to every supported device.
        batchdim: dimension along which inputs are split across replicas.
        drop_last: whether a trailing partial batch is dropped.
    """
    if device_ids is None:
        device_ids = xm.get_xla_supported_devices()
    self._device_ids = list(device_ids)
    self._batchdim = batchdim
    self._drop_last = drop_last
    # Set up replication only when there are devices AND the runtime reports
    # replication devices for them.
    self._replication = None
    if self._device_ids:
        replication_devices = xm.xla_replication_devices(self._device_ids)
        if replication_devices:
            self._replication = xm.Replication(self._device_ids,
                                               replication_devices)
    # Accept either an already-built module or a factory callable.
    template = network if isinstance(network, torch.nn.Module) else network()
    self._models = [
        deepcopy(template).to(device=torch.device(device))
        for device in device_ids
    ]
    if not self._models:
        # No XLA device, push a vanilla network in.
        self._models.append(network())
def main():
    """Pretrain a masked-language-model BERT on TPU cores.

    Parses CLI flags, replicates the model across the available XLA devices
    via DataParallel, trains one pregenerated data file per epoch, and saves
    a checkpoint after every epoch.
    """
    parser = utils.get_args_parser_with_general_args()
    parser.add_argument(
        '--one_tpu',
        action='store_true',
        help=
        "Run on one tpu core for degugging. Makes it easy to use break points")
    parser.add_argument(
        '--tpu_report', action='store_true', help="Print xla metric report")
    args = parser.parse_args()
    utils.init(args)  # set seeds, init logger, prepare output directory
    devices = tpu_xm.get_xla_supported_devices()
    if args.one_tpu:
        devices = [devices[0]]
    n_tpu = len(devices)
    logging.info(f'Found {n_tpu} TPU cores')
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    tokenizer.save_pretrained(args.output_dir)
    args.start_epoch = utils.prepare_last_checkpoint(args.bert_model)
    model = AutoModelWithLMHead.from_pretrained(
        args.bert_model)  # Only Masked Language Modeling
    logging.info(f"Saving initial checkpoint to: {args.output_dir}")
    model.save_pretrained(args.output_dir)
    # One replica of the model per TPU core.
    model = tpu_dp.DataParallel(model, device_ids=devices)
    num_data_epochs, num_train_optimization_steps = utils.get_dataset_stats(
        args, n_tpu)

    def tpu_training_loop(model, loader, device, context):
        """ Called by torch_xla_py.data_parallel. This function is executed on each core of the TPU once per epoch"""
        # Apply weight decay to everything except biases and LayerNorm params.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.01
        }, {
            'params':
                [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        # one optimizer and scheduler per TPU core. Both objects are saved in
        # `context` to be reused the next epoch.
        optimizer = context.getattr_or(
            'optimizer',
            AdamW(optimizer_grouped_parameters,
                  lr=args.learning_rate,
                  eps=args.adam_epsilon,
                  betas=tuple(args.betas)))
        # derive warmup info
        if args.warmup_proportion is not None:
            warmup_steps = int(args.warmup_proportion *
                               num_train_optimization_steps + 0.5)
        elif args.warmup_steps is not None:
            warmup_steps = args.warmup_steps
        else:
            raise Exception(
                'What is the warmup?? Specify either warmup proportion or steps'
            )
        scheduler = context.getattr_or(
            'scheduler',
            WarmupLinearSchedule(optimizer,
                                 warmup_steps=warmup_steps,
                                 t_total=num_train_optimization_steps))
        tr_loss = None
        pbar = None
        # NOTE: `pbar_device` and `pbar_steps` are closed over from the epoch
        # loop below; they are assigned there before this function runs.
        if str(pbar_device) == str(
                device
        ):  # All threads are in sync. Use progress bar only on one of them
            pbar = tqdm(total=int(pbar_steps),
                        desc=f"device {device}",
                        dynamic_ncols=True)
        tracker = tpu_xm.RateTracker()
        model.train()
        for step, batch in loader:
            input_ids, input_mask, segment_ids, lm_label_ids, _ = batch
            outputs = model(input_ids, segment_ids, input_mask, lm_label_ids)
            loss = outputs[0]
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tracker.add(args.train_batch_size)
            # Keep the running loss un-normalized by the accumulation factor.
            tr_loss = loss * args.gradient_accumulation_steps if step == 0 else tr_loss + loss * args.gradient_accumulation_steps
            if pbar is not None:
                pbar.update(1)
                # pbar.set_description(desc=f'LR: {scheduler.get_lr()}')
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Only step the optimizer/scheduler every `accumulation` batches.
                tpu_xm.optimizer_step(optimizer)
                prev_lr = scheduler.get_last_lr()[0]
                scheduler.step()
                curr_lr = scheduler.get_last_lr()[0]
                if args.track_learning_rate:
                    if pbar is not None:
                        pbar.set_description(
                            f"Prev LR: {prev_lr} Curr LR: {curr_lr}")
                optimizer.zero_grad()
        # NOTE(review): this divides by the LAST step index rather than the
        # step count (off by one), and raises ZeroDivisionError for a
        # single-batch epoch — confirm intended.
        return tr_loss.item(
        ) / step  # `.item()` requires a trip from TPU to CPU, which is very slow. Use it only once per epoch.

    for epoch in range(args.start_epoch, args.epochs):
        # Load one training file into memory
        epoch_dataset = utils.PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory)
        train_sampler = RandomSampler(epoch_dataset)
        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        pbar_device = devices[0]
        pbar_steps = utils.compute_num_steps_in_epoch(
            num_samples=train_sampler.num_samples,
            batch_size=args.train_batch_size,
            grad_accum_steps=
            1,  # the pbar steps should not take into account grad accumulation steps
            n_tpu=n_tpu)
        logging.info(
            f'start training, epoch {epoch} on {len(devices)} cores for {pbar_steps} steps'
        )
        start = time.time()
        losses = model(
            tpu_training_loop, train_dataloader
        )  # calls `tpu_training_loop` multiple times, once per TPU core
        logging.info(
            f'Epoch {epoch} took {round(time.time() - start, 2)} seconds. Average loss: {sum(losses)/len(losses)}'
        )
        utils.save_checkpoint(model._models[0], epoch, args.output_dir)
    if args.tpu_report:
        logging.info(torch_xla._XLAC._xla_metrics_report())
def train_mnist():
    """Meta-train a relation network on the China-drinks SKU dataset.

    Returns:
        Final-epoch test accuracy in percent.
    """
    torch.manual_seed(1)
    # Step 1: init data folders
    print("init data folders", flush=True)
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tgtpu.china_drinks_sku_folders(
        DATASET_FOLDER, SAMPLE_NUM_PER_CLASS, QUERY_NUM_PER_CLASS,
        VALIDATION_SPLIT_PERCENTAGE)
    devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
    # Scale learning rate to num cores.
    # NOTE(review): `lr` is computed but the optimizer below uses the module
    # constant LEARNING_RATE instead — confirm which is intended.
    lr = FLAGS.lr * len(devices)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(CNN_Plus_RNEncoder, device_ids=devices)
    degrees = random.choice([0, 90, 180, 270])
    train_task = tgtpu.ChinaDrinksTask(metatrain_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       QUERY_NUM_PER_CLASS)
    train_sample_batch_dataloader = tgtpu.get_data_loader(
        train_task,
        image_size=IMAGE_SIZE,
        sample_num_per_class=SAMPLE_NUM_PER_CLASS,
        query_num_per_class=QUERY_NUM_PER_CLASS,
        train_shuffle=False,
        query_shuffle=True,
        rotation=degrees,
        num_workers=NO_OF_TPU_CORES)
    test_task = tgtpu.ChinaDrinksTask(metatest_character_folders, CLASS_NUM,
                                      SAMPLE_NUM_PER_CLASS,
                                      SAMPLE_NUM_PER_CLASS)
    test_sample_test_dataloader = tgtpu.get_data_loader(
        test_task,
        IMAGE_SIZE,
        sample_num_per_class=SAMPLE_NUM_PER_CLASS,
        query_num_per_class=QUERY_NUM_PER_CLASS,
        train_shuffle=False,
        query_shuffle=True,
        rotation=degrees,
        num_workers=NO_OF_TPU_CORES)

    def train_loop_fn(model, loader, device, context):
        # One meta-training pass over `loader` on a single device.
        relation_network = model
        #relation_network.apply(weights_init)
        relation_network_optim = torch.optim.Adam(
            relation_network.parameters(), lr=LEARNING_RATE)
        relation_network_scheduler = StepLR(
            relation_network_optim, step_size=100000, gamma=0.5)
        mse = nn.MSELoss()
        tracker = xm.RateTracker()
        for x, (samples, sample_labels, batches, batch_labels) in loader:
            relation_network.zero_grad()
            #relation_network_optim.zero_grad()
            relation_scores = relation_network(
                Variable(samples), Variable(batches))
            relations = relation_scores.view(-1, CLASS_NUM)
            # One-hot encode the query labels to regress relation scores against.
            one_hot_labels = Variable(
                torch.zeros(QUERY_NUM_PER_CLASS * CLASS_NUM,
                            CLASS_NUM).scatter_(1, batch_labels.view(-1, 1),
                                                1))
            loss = mse(relations, one_hot_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
            xm.optimizer_step(relation_network_optim)
            # BUGFIX: the original called scheduler.step(episode) with an
            # undefined name `episode` (NameError) before the optimizer step.
            # Step the scheduler once per iteration, after the optimizer
            # update, as PyTorch recommends.
            relation_network_scheduler.step()
            tracker.add(FLAGS.batch_size)
            # (Removed the unconditional per-step debug print: it forced a
            # loss.item() device sync on every iteration.)
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, x, loss.item(), tracker.rate()))

    def test_loop_fn(model, loader, device, context):
        # Returns the episode accuracy measured on a single device.
        relation_network = model
        total_rewards = 0
        for x, (samples, sample_labels, batches, batch_labels) in loader:
            relation_scores = relation_network(
                Variable(samples), Variable(batches))
            relations = relation_scores.view(-1, CLASS_NUM)
            _, predict_labels = torch.max(relations.data, 1)
            # BUGFIX: the original compared against an undefined name
            # `test_labels`; the ground truth for this batch is `batch_labels`.
            rewards = [
                1 if predict_labels[j] == batch_labels[j] else 0
                for j in range(CLASS_NUM * SAMPLE_NUM_PER_CLASS)
            ]
            total_rewards += np.sum(rewards)
        test_accuracy = total_rewards / 1.0 / CLASS_NUM / SAMPLE_NUM_PER_CLASS / TEST_EPISODE
        print('[{}] Accuracy={:.2f}%'.format(device, 100 * test_accuracy))
        return test_accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_sample_batch_dataloader)
        accuracies = model_parallel(test_loop_fn, test_sample_test_dataloader)
        # BUGFIX: average over the results actually returned. With
        # device_ids=[] (PyTorch/CPU engine) len(devices) is 0 and the
        # original raised ZeroDivisionError.
        accuracy = sum(accuracies) / len(accuracies)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())
    return accuracy * 100.0
logging.warning(f'This will get logged to file: {args.log_file}') else: logging.basicConfig(level=logging.INFO, format=log_format) # create output dir if os.path.exists(args.output_dir): y_or_n = input( f'Output Dir {args.output_dir} already exists. Write to same dir? (y/n)' ) if y_or_n != 'y': raise Exception('Set new output dir') else: os.makedirs(args.output_dir, exist_ok=True) # TPU devices devices = tpu_xm.get_xla_supported_devices() if args.one_tpu: devices = [devices[0]] n_tpu = len(devices) logging.info(f'Found {n_tpu} TPU cores') # set seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) # load tokenizer tokenizer = AutoTokenizer.from_pretrained(args.bert_model) logging.info(f"Saving tokenizer to: {args.output_dir}") tokenizer.save_pretrained(args.output_dir)
def train_cifar():
    """Train a binarized Wide-ResNet on CIFAR-100 across XLA devices.

    Saves the best model to 'model.pt' and returns the final-epoch test
    accuracy in percent.
    """
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.Lambda(lambda x: RandomPixelPad(x, padding=4)),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        Cutout(18, random_pixel=True),  # add Cutout
        transforms.Normalize((0.5071, 0.4865, 0.4409),
                             (0.2673, 0.2564, 0.2762)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4865, 0.4409),
                             (0.2673, 0.2564, 0.2762)),
    ])
    trainset = torchvision.datasets.CIFAR100(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    testset = torchvision.datasets.CIFAR100(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
               if FLAGS.num_cores != 0 else [])
    # Define model here
    model = WRN_McDonnell(20, 10, 100, binarize=True)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # One training epoch on a single device; optimizer/scheduler state is
        # cached in `context` so it survives across epochs.
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(model.parameters(),
                              lr=FLAGS.lr,
                              momentum=FLAGS.momentum,
                              weight_decay=5e-4))
        # LR scheduler
        scheduler = context.getattr_or(
            'scheduler',
            lambda: CosineAnnealingRestartsLR(optimizer, T=2, eta_min=1e-4))
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f}'.format(device, x, loss.item()))
        # Step LR scheduler once per epoch.
        scheduler.step()

    def test_loop_fn(model, loader, device, context):
        # Returns the fraction of correctly classified test samples.
        total_samples = 0
        correct = 0
        model.eval()
        # FIX: evaluation does not need gradients; no_grad() avoids building
        # the autograd graph (results are unchanged).
        with torch.no_grad():
            for x, (data, target) in loader:
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()
                total_samples += data.size()[0]
        return correct / total_samples

    best_accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        # BUGFIX: average over the results actually returned. With
        # device_ids=[] (the CPU engine advertised above) len(devices) is 0
        # and the original raised ZeroDivisionError.
        accuracy = sum(accuracies) / len(accuracies)
        print('Epoch {}, Accuracy={:.2f}%'.format(epoch, 100.0 * accuracy))
        # Keep track of best model
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model_parallel._models[0].state_dict(), 'model.pt')
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())
    return accuracy * 100.0
def main():
    """Fine-tune, evaluate, or predict with a BERT sequence classifier on TPU.

    Flags select the mode: --do_train trains with per-epoch evaluation;
    --do_predict writes a prediction CSV (scored against labels when
    --do_eval is also set).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The train file path")
    parser.add_argument("--eval_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The dev file path")
    parser.add_argument("--predict_file",
                        default=None,
                        type=str,
                        required=False,
                        help="The predict file path")
    parser.add_argument("--predict_result_file",
                        default=None,
                        type=str,
                        required=False,
                        help="The predict result file path")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help=
        "The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )
    parser.add_argument(
        "--init_checkpoint",
        default=None,
        type=str,
        help="Initial checkpoint (usually from a pre-trained BERT model).")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument(
        "--max_seq_length",
        default=300,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    # NOTE(review): help text for --do_eval looks copy/pasted from --do_train.
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--num_labels",
                        default=1,
                        type=int,
                        help="mapping classify nums")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=6.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    args = parser.parse_args()

    vocab_path = os.path.join(args.bert_model, VOCAB_NAME)
    # bert_config = BertConfig.from_json_file(vocab_path)
    data_processor = DataProcessor()
    devices = tpu_xm.get_xla_supported_devices()
    n_tpu = len(devices)
    logging.info(f'Found {n_tpu} TPU cores')
    # The effective batch is train_batch_size; each micro-batch is smaller by
    # the accumulation factor.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.do_train:
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
        else:
            os.makedirs(args.output_dir, exist_ok=True)
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path,
                                           do_lower_case=args.do_lower_case)
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          num_labels=3)
    # Debug dump of parameter state; gradients are None before training.
    for k, v in model.state_dict().items():
        print(f'k = {k}, v.grad = {v.grad}')
    model = tpu_dp.DataParallel(model, device_ids=devices)
    if args.do_train:
        # Read the raw examples.
        train_examples = data_processor.get_examples(args.train_file,
                                                     data_type='train')
        eval_examples = data_processor.get_examples(args.eval_file,
                                                    data_type='eval')
        # Convert examples to input features.
        train_features = convert_examples_to_features(args, train_examples,
                                                      args.max_seq_length,
                                                      tokenizer)
        eval_features = convert_examples_to_features(args, eval_examples,
                                                     args.max_seq_length,
                                                     tokenizer)
        num_train_steps = int(
            len(train_features) // args.train_batch_size //
            args.gradient_accumulation_steps * args.num_train_epochs)
        # Build dataset wrappers.
        train_loader = ParaDataloader(train_features)
        eval_loader = ParaDataloader(eval_features)
        # Wrap in DataLoader: the input format the data-parallel loop expects.
        train_loader = DataLoader(train_loader,
                                  shuffle=True,
                                  batch_size=args.train_batch_size)
        eval_loader = DataLoader(eval_loader,
                                 shuffle=False,
                                 batch_size=args.eval_batch_size)

    def tpu_training_loop(model, loader, device, context):
        """ Called by torch_xla_py.data_parallel. This function is executed on each core of the TPU once per epoch"""
        model.zero_grad()
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        param_optimizer = list(model.named_parameters())
        # NOTE(review): `n not in no_decay` tests the FULL parameter name for
        # exact list membership; unlike the usual `any(nd in n ...)` substring
        # check, this places virtually every parameter in the decay group and
        # leaves the no-decay group empty — confirm intended.
        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer if n not in no_decay],
            'weight_decay_rate': 0.01
        }, {
            'params': [p for n, p in param_optimizer if n in no_decay],
            'weight_decay_rate': 0.0
        }]
        # One optimizer per TPU core, cached in `context` across epochs.
        optimizer = context.getattr_or(
            'optimizer',
            BertAdam(optimizer_grouped_parameters,
                     lr=args.learning_rate,
                     warmup=args.warmup_proportion,
                     t_total=num_train_steps))
        tr_loss = None
        pbar = None
        # `pbar_device`/`pbar_steps` are closed over from the driver code
        # below; the progress bar runs on one core only.
        if str(pbar_device) == str(device):
            pbar = tqdm(total=int(pbar_steps),
                        desc=f"training",
                        dynamic_ncols=True)
        tracker = tpu_xm.RateTracker()
        model.train()
        for step, batch in enumerate(loader):
            input_ids, input_mask, segment_ids, label_ids = batch
            loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tracker.add(args.train_batch_size)
            tr_loss = loss * args.gradient_accumulation_steps if step == 0 else tr_loss + loss * args.gradient_accumulation_steps
            if pbar is not None:
                pbar.update(1)
            tpu_xm.optimizer_step(optimizer)
            # optimizer.step()
            optimizer.zero_grad()
        # NOTE(review): divides by the last step INDEX, not the step count
        # (off by one), and raises ZeroDivisionError for a single-batch epoch.
        return tr_loss.item() / step

    def tpu_evaluating_loop(model, eval_dataloader, device, context):
        # Per-core evaluation; returns (avg loss, predicted labels, gold labels).
        model.eval()
        eval_loss = 0
        eval_pbar = None
        logits, labels = [], []
        if str(pbar_device) == str(device):
            eval_pbar = tqdm(total=int(eval_pbar_steps),
                             desc=f"evaluating",
                             dynamic_ncols=True)
        tracker = tpu_xm.RateTracker()
        for step, batch in enumerate(eval_dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            with torch.no_grad():
                loss, logit = model(input_ids, segment_ids, input_mask,
                                    label_ids)
            eval_loss = loss * args.gradient_accumulation_steps if step == 0 else eval_loss + loss * args.gradient_accumulation_steps
            logit = torch.argmax(logit, dim=-1)
            logits.extend(logit.tolist())
            labels.extend(label_ids.tolist())
            tracker.add(args.eval_batch_size)
            if eval_pbar is not None:
                eval_pbar.update(1)
        # NOTE(review): same off-by-one / zero-division hazard as training.
        return (eval_loss.item() / step, logits, labels)

    def tpu_predicting_loop(model, dataloader, device, context):
        # Per-core prediction; returns (predicted labels, example ids, probs).
        # Here `label_ids` carries example ids, not gold labels.
        model.eval()
        eval_pbar = None
        logits, example_ids, probs = [], [], []
        if str(pbar_device) == str(device):
            eval_pbar = tqdm(total=int(eval_pbar_steps),
                             desc=f"evaluating",
                             dynamic_ncols=True)
        tracker = tpu_xm.RateTracker()
        for step, batch in enumerate(dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            with torch.no_grad():
                logit = model(input_ids, segment_ids, input_mask)
            prob = torch.softmax(logit, dim=-1).tolist()
            logit = torch.argmax(logit, dim=-1)
            logits.extend(logit.tolist())
            example_ids.extend(label_ids.tolist())
            probs.extend(prob)
            tracker.add(args.eval_batch_size)
            if eval_pbar is not None:
                eval_pbar.update(1)
        return logits, example_ids, probs

    def eval_meric(model, loop, data_loader):
        # Aggregate the per-core eval results and log overall loss/accuracy.
        eval_results = model(loop, data_loader)
        # NOTE(review): `eval_loss` is bound twice in this tuple assignment;
        # the second name was probably meant to be a different counter.
        eval_loss, eval_loss = 0, 0
        all_logits, all_labels = [], []
        # NOTE(review): hardcodes an 8-core TPU; breaks with any other count.
        assert len(eval_results) == len(devices) == 8
        for eval_result in eval_results:
            eval_loss += eval_result[0]
            all_logits.extend(eval_result[1])
            all_labels.extend(eval_result[2])
        accuracy(all_labels, all_logits)
        logger.info(f'Average eval loss = {eval_loss / len(eval_results)}')

    def write_predict_file(model, loop, data_loader, file_path):
        """Write the prediction file. Output format: '五彩滨云-final.csv'."""
        results = model(loop, data_loader)
        logits, ids, probs = [], [], []
        # NOTE(review): hardcodes an 8-core TPU; breaks with any other count.
        assert len(results) == len(devices) == 8
        for result in results:
            logits.extend(result[0])
            ids.extend(result[1])
            probs.extend(result[2])
        assert len(ids) == len(logits)
        logger.info(
            f'zero nums {logits.count(0)}, one nums {logits.count(1)}, two nums {logits.count(2)}'
        )
        labels = [
            data_processor.eval_dict[id][1] for id, logit in zip(ids, logits)
        ]
        if not args.do_eval:
            # Shift class ids back to the original label space for submission.
            logits = [i - 1 for i in logits]
            data_df = pd.DataFrame({'id': ids, 'y': logits})
            data_df1 = pd.DataFrame({'id': ids, 'y': logits, 'probs': probs})
            data_df1.to_csv('probs_predict.csv', index=None)
        else:
            assert len(labels) == len(logits)
            accuracy(labels, logits)
            passages = [
                data_processor.eval_dict[id][0]
                for id, logit in zip(ids, logits)
            ]
            assert len(labels) == len(passages)
            match_array = np.array((logits)) == np.array(labels)
            match_list = match_array.tolist()
            data_df = pd.DataFrame({
                'id': ids,
                'pred': logits,
                'real': labels,
                'probs': probs,
                'match': match_list,
                'passage': passages
            })
        # Write whichever DataFrame the branch above produced.
        data_df.to_csv(file_path, index=None)

    if args.do_train:
        for epoch in range(1, int(args.num_train_epochs) + 1, 1):
            pbar_device = devices[0]
            logger.info(f'Start to evaluate......')
            eval_pbar_steps = len(eval_loader) // n_tpu
            eval_meric(model, tpu_evaluating_loop, eval_loader)
            pbar_steps = len(train_loader) // n_tpu
            logging.info(
                f'Start training, epoch {epoch} on {len(devices)} cores for {pbar_steps} steps'
            )
            start = time.time()
            losses = model(tpu_training_loop, train_loader)
            logging.info(
                f'Epoch {epoch} took {round(time.time() - start, 2)} seconds. average train loss: {sum(losses) / len(losses)}'
            )
            save_checkpoint(model._models[0], epoch, args.output_dir)
        logger.info('Train finished......')
    elif args.do_predict:
        pbar_device = devices[0]
        logger.info(f'Start to predict......')
        if args.do_eval:
            predict_examples = data_processor.get_eval_examples(args.eval_file)
        else:
            predict_examples = data_processor.get_predict_examples(
                args.predict_file)
        predict_features = convert_examples_to_features(
            args, predict_examples, args.max_seq_length, tokenizer)
        predict_loader = ParaDataloader(predict_features)
        predict_loader = DataLoader(predict_loader,
                                    shuffle=False,
                                    batch_size=args.eval_batch_size)
        eval_pbar_steps = len(predict_loader) // n_tpu
        write_predict_file(model, tpu_predicting_loop, predict_loader,
                           args.predict_result_file)
def train_mnist():
    """Train the MNIST model across XLA devices; return final test accuracy in percent."""
    torch.manual_seed(1)
    if FLAGS.fake_data:
        # Synthetic zero tensors shaped like MNIST; skips the dataset download.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size)
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                FLAGS.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)
        # NOTE(review): shuffle=True on the test loader is unusual; accuracy
        # is unaffected but batches are not reproducible run-to-run.
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                FLAGS.datadir,
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)

    devices = (
        xm.get_xla_supported_devices(
            max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # One training epoch on a single device. The optimizer is cached in
        # `context` so its state survives across epochs.
        loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(model.parameters(), lr=lr,
                              momentum=FLAGS.momentum))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, x, loss.item(), tracker.rate()))

    def test_loop_fn(model, loader, device, context):
        # Returns the fraction of correctly classified test samples.
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        print('[{}] Accuracy={:.2f}%'.format(device,
                                             100.0 * correct / total_samples))
        return correct / total_samples

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = sum(accuracies) / len(accuracies)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())
    return accuracy * 100.0
update_freq=[1], upsample_primary=16, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=4000, weight_decay=0.0) task = tasks.setup_task(args) task.load_dataset(args.train_subset, combine=True, epoch=0) for valid_sub_split in args.valid_subset.split(','): task.load_dataset(valid_sub_split, combine=True, epoch=0) #devices = xm.get_xla_supported_devices(max_devices=8) # Got error for max devices argument :( devices = xm.get_xla_supported_devices() model_parallel = dp.DataParallel(lambda: task.build_model(args), device_ids=devices) #max_positions = utils.resolve_max_positions( # task.max_positions(), # model.max_positions(), # # ) max_positions = (1024, 1024 ) # Hardcoded for the moment since the computation requires # model object which will be created by model_parallel __call__ # Re-factor in a cleaner way # Initialize dataloader epoch_itr = task.get_batch_iterator(
def train_imagenet():
  """Train a torchvision model on ImageNet (or fake data) on XLA devices.

  Builds train/test pipelines, replicates the model with dp.DataParallel
  over the available XLA devices (or the PyTorch/CPU engine when the device
  list is empty), and runs FLAGS.num_epochs of train + eval.

  Returns:
    The final epoch's mean test accuracy as a percentage (0.0-100.0).
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    # Synthetic zero tensors shaped like ImageNet batches; sample counts
    # mimic the real dataset sizes (1.2M train / 50K val images).
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=1200000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size)
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    # NOTE(review): eval uses the same random crop/flip augmentation (and
    # shuffling) as training, which makes the reported accuracy noisy;
    # Resize+CenterCrop is the conventional eval transform -- confirm
    # intent before changing measured behavior.
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)
  devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Cache the optimizer on the per-device context so its momentum state
    # persists across epochs instead of being reset by a fresh SGD each call.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()
    model.train()  # train-mode batchnorm/dropout
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
                                                        tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()  # eval-mode batchnorm/dropout for stable accuracy
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    # Average over the replicas that actually ran: with device_ids=[] the
    # CPU engine still yields one result while len(devices) would be 0
    # (ZeroDivisionError); len(accuracies) is always >= 1.
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
  return accuracy * 100.0
# NOTE(review): fragment -- the 'if' guarding this raise and the preceding
# checkpoint lookup (plus 'model', 'resume', 'args') live outside this chunk.
raise Exception("no checkpoints to load")
model_dict = model.state_dict()
pretrained_dict = torch.load(resume)
# Keep only entries whose key (with the first 7 characters dropped --
# presumably a 'module.' nn.DataParallel prefix; confirm) exists in the
# current model's state dict, then merge them in.
pretrained_dict = {
    k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict
}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print('Set cache dir', flush=True)
# NOTE(review): binding 'time' shadows the stdlib 'time' module if it is
# imported at file level.
time = datetime.datetime.now()
num_cores = 8
devices = (xm.get_xla_supported_devices(
    max_devices=num_cores) if num_cores != 0 else [])
# Scale learning rate to num cores
base_lr = args.base_lr * max(len(devices), 1)
# Pass [] as device_ids to run using the PyTorch/CPU engine.
model_parallel = dp.DataParallel(model, device_ids=devices)
# optimizer prepare
# Collect ids of the newly added head layers so the backbone parameters can
# be given a different learning rate below.
ignored_params1 = list(map(id, model.classifier.parameters()))
ignored_params2 = list(map(id, model.classifier_swap.parameters()))
ignored_params3 = list(map(id, model.Convmask.parameters()))
ignored_params = ignored_params1 + ignored_params2 + ignored_params3
print('the num of new layers:', len(ignored_params), flush=True)
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
def train_imagenet():
  """Train a torchvision model on ImageNet across XLA replicas.

  Supports fake data, per-replica distributed sampling when
  xm.xrt_world_size() > 1, and optional TensorBoard logging via
  FLAGS.logdir.

  Returns:
    The final epoch's mean test accuracy (as returned by test_loop_fn,
    already in percent).
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=1200000 // FLAGS.batch_size // xm.xrt_world_size())
    # The test generator yields test_set_batch_size-sized batches, so the
    # per-replica sample count must be divided by that same batch size
    # (the original divided by FLAGS.batch_size, skewing the eval length).
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.test_set_batch_size // xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))
    # When running multi-replica, shard both datasets so each ordinal sees
    # a disjoint slice; otherwise fall back to plain shuffling.
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)
  devices = (xm.get_xla_supported_devices(
      max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer cached on the per-device context so momentum state persists
    # across epochs.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
  if writer:
    # Flush pending TensorBoard events; the original leaked the writer.
    writer.close()
  return accuracy
def train_cifar():
  """Train ResNet18 on CIFAR-10 (or fake data) on the available XLA devices.

  Replicates the model with dp.DataParallel (or the PyTorch/CPU engine when
  the device list is empty) and runs FLAGS.num_epochs of train + eval.

  Returns:
    The final epoch's mean test accuracy as a percentage (0.0-100.0).
  """
  print('==> Preparing data..')
  if FLAGS.fake_data:
    # Synthetic zero tensors; sample counts mimic CIFAR-10's 50K/10K split.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size)
  else:
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    testset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)
  devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(ResNet18, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Cache the optimizer on the per-device context so its momentum state
    # persists across epochs instead of being reset by a fresh SGD each call.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()
    model.train()  # train-mode batchnorm/dropout
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
            device, x, loss.item(), tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()  # eval-mode batchnorm/dropout for stable accuracy
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    # Average over the replicas that actually ran: with device_ids=[] the
    # CPU engine still yields one result while len(devices) would be 0
    # (ZeroDivisionError); len(accuracies) is always >= 1.
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
  return accuracy * 100.0
def main_tpu(args):
  """Drive fairseq training/validation across XLA devices.

  Sets up per-device trainers via prepare_task, then loops: train one epoch
  on every replica, validate, step the learning rate, and checkpoint, until
  keep_training() says to stop.
  """

  def log_step(step_type, device, step, tracker=None, metrics_debug=False):
    # Format a progress line; rates are appended only when a tracker is given.
    # NOTE(review): the metrics_debug parameter is accepted but never used here.
    msg = '{}/ {}, device {}, step {}'.format(step_type, utils.now(), device,
                                              step)
    if tracker:
      rates = tracker.rate(), tracker.global_rate()
      msg += ', Rate={:.2f}, Global Rate={:.2f}'.format(*rates)
    return msg

  def train_loop_fn(model, loader, device, context):
    # Per-replica training loop; looks up this device's trainer by name.
    trainer = trainers[str(device)]
    stats = None
    tracker = xm.RateTracker()
    for i, samples in loader:
      if i and not (i % args.log_steps):
        print(
            log_step(
                'training',
                device,
                i,
                tracker=tracker,
                metrics_debug=args.metrics_debug))
      _log_output = trainer.train_step(samples)
      xm.optimizer_step(trainer.optimizer)
      tracker.add(len(samples) * args.max_sentences)  # n_batches * batch_size
    stats = fairseq_train.get_training_stats(trainer)
    return tracker, stats

  def valid_loop_fn(model, loader, device, context):
    # Per-replica validation loop; returns a stats dict for this device.
    trainer = trainers[str(device)]
    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
      meter = trainer.get_meter(k)
      if meter is not None:
        meter.reset()
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in loader:
      if not (i % args.log_steps):
        print(
            log_step(
                'validation',
                device,
                i,
                tracker=None,
                metrics_debug=args.metrics_debug))
      log_output = trainer.valid_step(sample)
      for k, v in log_output.items():
        # Core loss keys are already tracked by the trainer's own meters.
        if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
          continue
        extra_meters[k].update(v)
    stats = fairseq_train.get_valid_stats(trainer)
    for k, meter in extra_meters.items():
      stats[k] = meter.avg
    return stats

  def validate_subset(args, trainers, task, epoch_itr, subset):
    # Validate one named dataset split on every replica; returns one loss
    # average per device.
    print('Validating the subset "{}"'.format(subset))
    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            list(trainers.values())[0].get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_workers=args.num_workers).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        prefix='valid on \'{}\' subset'.format(subset),
        no_progress_bar='simple')
    stats_per_device = model_parallel(valid_loop_fn, progress)
    valid_losses = [stats['loss'].avg for stats in stats_per_device]
    print('validation stats on subset "{}" - {}'.format(subset, utils.now()))
    for stats in stats_per_device:
      # NOTE(review): 'trainer' is not defined in this scope (only the
      # 'trainers' dict is) -- this looks like it would raise NameError
      # unless a module-level 'trainer' exists; confirm and pick a trainer
      # explicitly.
      progress.print(stats, tag=subset, step=trainer.get_num_updates())
    return valid_losses

  def validate(args, trainers, task, epoch_itr, subsets):
    # Map each subset name to its per-device validation losses.
    valid_losses = {
        subset: validate_subset(args, trainers, task, epoch_itr, subset)
        for subset in subsets
    }
    return valid_losses

  def initialize_loader_for_epoch(args, epoch_itr):
    # Pick this epoch's gradient-accumulation factor from args.update_freq
    # (last entry is reused once the schedule runs out).
    if epoch_itr.epoch <= len(args.update_freq):
      update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
      update_freq = args.update_freq[-1]
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=False,
        shuffle=(epoch_itr.epoch >= args.curriculum))
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple')
    return progress

  def keep_training(lr, epoch_itr, trainers):
    # Train until the learning rate gets too small
    # NOTE(review): the 'lr' parameter is immediately shadowed by the min
    # over all trainers, and the threshold reads FLAGS.min_lr while the rest
    # of this function uses 'args' -- confirm both are intentional.
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = min(trainer.get_lr() for trainer in trainers.values())
    n_updates = max(trainer.get_num_updates() for trainer in trainers.values())
    return ((lr > FLAGS.min_lr) and (epoch_itr.epoch < max_epoch) and
            (n_updates < max_update))

  # Dump the full argument namespace for reproducibility.
  xu.eprint('Args')
  for key, val in args.__dict__.items():
    xu.eprint('\t{} {}'.format(key, val))
  xu.eprint('---------')
  devices = xm.get_xla_supported_devices(max_devices=args.num_cores)
  task, trainers, model_parallel, epoch_itr, lr, valid_subsets = prepare_task(
      args, devices)
  train_meter = StopwatchMeter()
  train_meter.start()
  while keep_training(lr, epoch_itr, trainers):
    # TRAINING
    print('Epoch {} begin {}'.format(epoch_itr.epoch + 1, utils.now()))
    progress = initialize_loader_for_epoch(args, epoch_itr)
    out = model_parallel(train_loop_fn, progress)
    trackers, stats_ = zip(*out)
    print('Epoch {} Training stats:'.format(epoch_itr.epoch))
    for device, trainer in trainers.items():
      stats = fairseq_train.get_training_stats(trainer)
      print('device {}'.format(device))
      progress.print(stats, tag=device)
    print('Epoch {} Tracker Rates:'.format(epoch_itr.epoch))
    for tracker in trackers:
      rates = tracker.rate(), tracker.global_rate()
      print('\tRate={:.2f}, Global Rate={:.2f}'.format(*rates))
    print('Epoch {} end {}'.format(epoch_itr.epoch, utils.now()))
    if args.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
    # VALIDATION
    if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
      valid_losses = validate(args, trainers, task, epoch_itr, valid_subsets)
      # only use average first validation loss from the first device
      # to update the learning rate
      vloss = valid_losses[valid_subsets[0]][0]
      print('old learning rate: {}'.format(lr))
      lr = trainers[devices[0]].lr_step(epoch_itr.epoch, vloss)
      print('new learning rate: {}'.format(lr))
    # save checkpoint
    if epoch_itr.epoch % args.save_interval == 0:
      # NOTE(review): 'trainer' here is whatever the training-stats loop
      # above left bound (the last dict entry), and 'vloss' is unbound when
      # validation was skipped this epoch -- both look fragile; confirm.
      checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, vloss)
    if args.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
  train_meter.stop()
  print('| done training in {:.1f} seconds'.format(train_meter.sum))
def test_get_real_xla_devices(self):
  """Every supported XLA device must resolve to a real CPU/GPU/TPU device."""
  supported = xm.get_xla_supported_devices()
  real_names = torch_xla._XLAC._xla_real_devices(supported)
  pattern = re.compile(r'(CPU|GPU|TPU):\d+$')
  for _, real_name in zip(supported, real_names):
    self.assertTrue(pattern.match(real_name) is not None)
def train_cifar():
  """Train ResNet18 (or torchvision's resnet18) on CIFAR-10 over XLA replicas.

  Builds real or fake data pipelines (sharded with DistributedSampler when
  xm.xrt_world_size() > 1), replicates the model with dp.DataParallel, and
  runs FLAGS.num_epochs of train + eval.

  Returns:
    The final epoch's mean test accuracy in percent.
  """
  print('==> Preparing data..')
  if FLAGS.fake_data:

    def fake_loader(dataset_size):
      # Zero-filled batches; per-replica count shrinks with world size.
      return xu.SampleGenerator(
          data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
          sample_count=dataset_size // FLAGS.batch_size // xm.xrt_world_size())

    train_loader = fake_loader(50000)
    test_loader = fake_loader(10000)
  else:
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    plain = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir, train=True, download=True, transform=augment)
    test_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir, train=False, download=True, transform=plain)
    # Shard both splits across replicas when running multi-core.
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)
  devices = (xm.get_xla_supported_devices(
      max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  net_fn = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
  model_parallel = dp.DataParallel(net_fn, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    criterion = nn.CrossEntropyLoss()

    def make_optimizer():
      return optim.SGD(
          model.parameters(),
          lr=FLAGS.lr,
          momentum=FLAGS.momentum,
          weight_decay=5e-4)

    # Built once per device and cached on the context so momentum survives
    # across epochs.
    optimizer = context.getattr_or('optimizer', make_optimizer)
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in loader:
      optimizer.zero_grad()
      loss = criterion(model(data), target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if step % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, step, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    model.eval()
    correct, total_samples = 0, 0
    for step, (data, target) in loader:
      pred = model(data).max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
  return accuracy