def test(self):
  """Smoke-train a ResNet18 replica per XLA device on synthetic data.

  Uses zeroed image/label batches so no dataset is needed, wraps each
  phase of the step in the (silenced) timing helpers, and asserts the
  NLL loss stays below a sanity bound on every step.
  """
  devices = xm.get_xla_supported_devices()
  batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
  sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
  # One `sample_count` worth of synthetic batches per device.
  train_loader = xu.SampleGenerator(
      data=(torch.zeros(batch_size, 3, 224, 224),
            torch.zeros(batch_size, dtype=torch.int64)),
      sample_count=sample_count * len(devices))

  def per_device_loop(model, loader, device, context):
    # Fresh criterion and optimizer for each replica.
    criterion = nn.NLLLoss()
    sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for inputs, labels in loader:
      # printfn=None keeps the timers quiet; they are exercised, not logged.
      with xu.TimedScope(msg='Training loop: ', printfn=None):
        sgd.zero_grad()
        logits = xu.timed(lambda: model(inputs), msg='Model: ', printfn=None)
        loss = xu.timed(
            lambda: criterion(logits, labels), msg='Loss: ', printfn=None)
        xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
        xu.timed(
            lambda: xm.optimizer_step(sgd), msg='Step: ', printfn=None)
        self.assertLess(loss.cpu().item(), 3.0)

  model_parallel = dp.DataParallel(
      torchvision.models.resnet18, device_ids=devices)
  model_parallel(per_device_loop, train_loader)
def inference():
  """Run data-parallel inference over the RSNA patient list.

  Loads the weights from ``args.weight_file``, pads the patient list so it
  splits evenly across replicas, runs ``infer_loop`` on every device and
  writes the concatenated (score, patientid) predictions via
  ``prediction_to_df`` to ``args.subm_file``.
  """
  model = model_from_name(args.model_name)
  devices = (
      xm.get_xla_supported_devices(max_devices=args.num_cores)
      if args.num_cores != 0 else [])
  model.load_state_dict(torch.load(args.weight_file))
  model_parallel = dp.DataParallel(model, device_ids=devices)
  patient = extract_patient(args.csv_file_path, return_label=False)
  ## The number of samples must be a multiple of num_cores * batch_size so
  ## every replica receives only full batches; pad by wrapping around.
  # BUGFIX: the original expression
  #   num_cores * batch_size - len(patient) % num_cores * batch_size
  # parsed as (len(patient) % num_cores) * batch_size due to operator
  # precedence, producing a wrong pad size (and ZeroDivision for
  # num_cores == 0). Compute the chunk explicitly instead; the trailing
  # `% chunk` also avoids padding a full extra chunk when already aligned.
  chunk = max(args.num_cores, 1) * args.batch_size
  extra = (chunk - len(patient) % chunk) % chunk
  patient = np.concatenate([patient, patient[:extra]])
  ds = RSNADataset(patient, path=args.path, transform=train_transforms)
  # shuffle=False: prediction order must stay aligned with `patient`.
  loader = D.DataLoader(
      ds,
      batch_size=args.batch_size,
      shuffle=False,
      num_workers=args.num_workers)
  result = model_parallel(infer_loop, loader)
  result = np.array(result)
  # Each per-device row is (score, patientid); split columns then flatten.
  score, patientid = [result[:, i] for i in range(result.shape[-1])]
  score = np.concatenate(score)
  patientid = np.concatenate(patientid)
  prediction_to_df(score, patientid, args.subm_file)
def train():
  """Train a torchvision segmentation model on FashionDataset across XLA devices.

  Loads the mask annotations from ``args.json_file``, re-heads the chosen
  segmentation model for 46 classes, trains with the module-level
  ``train_loop_fn`` for ``args.epochs`` epochs, then saves the weights of
  the first replica to ``args.save_file``.
  """
  set_seeds()
  logging.info('Loading masks...')
  with open(args.json_file, 'r') as f:
    masks = json.load(f)
  # for example use only 200 images
  filename = list(masks.keys())[:200]
  # Shared with the module-level train_loop_fn through globals.
  global devices, num_steps_per_epoch
  devices = (xm.get_xla_supported_devices(
      max_devices=args.num_cores) if args.num_cores != 0 else [])
  logging.info('Start training model')
  if args.model_name == 'deeplabv3_resnet50':
    m = torchvision.models.segmentation.deeplabv3_resnet50(False)
  else:
    m = torchvision.models.segmentation.fcn_resnet50(False)
  # Replace the classifier head with a 1x1 conv for 46 output classes.
  m.classifier[-1] = torch.nn.Conv2d(m.classifier[-1].in_channels, 46, 1)
  # wrapped for parallel training
  model = dp.DataParallel(m, device_ids=devices)
  ds = FashionDataset(
      filename,
      masks,
      path=args.data_path,
      transform=train_transform,
      size=(256, 256))
  loader = D.DataLoader(
      ds,
      batch_size=args.batch_size,
      shuffle=True,
      num_workers=args.num_worker)
  # NOTE(review): divides by len(devices); args.num_cores == 0 (CPU fallback,
  # devices == []) would raise ZeroDivisionError — confirm num_cores > 0.
  num_steps_per_epoch = len(loader) // len(devices)
  for epoch in range(1, args.epochs + 1):
    # One mean loss value per device is returned by DataParallel.
    train_loss = model(train_loop_fn, loader)
    train_loss = np.array(train_loss).mean()
    logging.info('[Epoch {:3d}] Train loss: {:.3f}'.format(
        epoch, train_loss))
  # Save weights
  # NOTE(review): .to('cpu') mutates the first replica in place; this is
  # safe here because training has finished before the save.
  state_dict = model.models[0].to('cpu').state_dict()
  torch.save(state_dict, args.save_file)
  logging.info('')
  logging.info('Model saved\n')
def train():
  """Train an RSNA classification model across XLA devices and save weights.

  Builds the dataset from ``args.csv_file_path`` (image- or DICOM-based,
  depending on ``args.use_image``), trains for ``args.num_epochs`` epochs
  with the module-level ``train_loop_fn``, and saves the first replica's
  weights to ``args.save_pht``.
  """
  set_seeds()
  # Shared with the module-level train_loop_fn through a global.
  global num_steps_per_epoch
  ## Create and wrap model ##
  model = model_from_name(args.model_name)
  devices = (
      xm.get_xla_supported_devices(max_devices=args.num_cores)
      if args.num_cores != 0 else [])
  model_parallel = dp.DataParallel(model, device_ids=devices)
  logging.info('Model {} loaded and wrapped'.format(args.model_name))
  logging.info('')
  ## Create dataset and loader
  patient, label = extract_patient(args.csv_file_path, return_label=True)
  if args.use_image:
    ds = RSNAImages(
        patient, label, path=args.path, transform=train_transforms)
  else:
    ds = RSNADataset(
        patient, label, path=args.path, transform=train_transforms)
  loader = D.DataLoader(
      ds,
      batch_size=args.batch_size,
      shuffle=True,
      num_workers=args.num_workers)
  logging.info('Dataset created\n')
  logging.info('Start training model\n')
  # NOTE(review): ZeroDivisionError if devices == [] (num_cores == 0) —
  # confirm num_cores is always > 0 in practice.
  num_steps_per_epoch = len(loader) // len(devices)
  for epoch in range(1, args.num_epochs + 1):
    start_time = time.time()
    model_parallel(train_loop_fn, loader)
    logging.info('')
    logging.info('Epoch training time: {:.2f} minutes\n'.format(
        (time.time() - start_time) / 60**1))
  # Save weights
  # Moves the first replica to CPU in place; safe post-training.
  state_dict = model_parallel.models[0].to('cpu').state_dict()
  torch.save(state_dict, args.save_pht)
  logging.info('')
  logging.info('Model saved')
  logging.info('')
def main(model, dataset, train_pairs, qrels, valid_run, qrelf, model_out_dir): params = [(k, v) for k, v in model.named_parameters() if v.requires_grad] non_bert_params = { 'params': [v for k, v in params if not k.startswith('bert.')] } bert_params = { 'params': [v for k, v in params if k.startswith('bert.')], 'lr': BERT_LR } optimizer = torch.optim.Adam([non_bert_params, bert_params], lr=LR) # optimizer = torch.optim.SGD([non_bert_params, bert_params], lr=LR, momentum=0.9) # model.to(device) model_parallel = dp.DataParallel(model, device_ids=devices) epoch = 0 top_valid_score = None for epoch in range(MAX_EPOCH): # loss = train_iteration(model, optimizer, dataset, train_pairs, qrels) # print(f'train epoch={epoch} loss={loss}') # # return train_set = TrainDataset(it=data.iter_train_pairs( model, dataset, train_pairs, qrels, 1), length=BATCH_SIZE * BATCHES_PER_EPOCH) train_loader = torch.utils.data.DataLoader( train_set, batch_size=GRAD_ACC_SIZE, ) # for i, tr in enumerate(train_loader): # for tt in tr: # print(tt, tr[tt].size()) # break # print('finished') # return model_parallel(train_iteration_multi, train_loader) '''
def train_mnist(FLAGS):
  """Train the MNIST model on synthetic data with per-epoch manual loops.

  Unlike the other training entry points in this file, the train loop is
  invoked directly (single device) while the test loop goes through
  ``dp.DataParallel``. Data is always synthetic (``xu.SampleGenerator``).
  Returns the maximum mean test accuracy observed across epochs.
  """
  DTYPE = torch.float32
  torch.manual_seed(1)
  dims = (
      FLAGS.batch_size,
      1,
      784,
  )
  train_dataset_len = FLAGS.steps_per_epoch if FLAGS.steps_per_epoch else 60000
  # Synthetic batches; labels are float when using MSE loss, int64 for NLL.
  train_loader = xu.SampleGenerator(
      data=(
          torch.ones(dims, dtype=DTYPE,),
          torch.ones(
              FLAGS.batch_size,
              dtype=torch.int64 if not _MSE_LOSS else DTYPE,
          ),
      ),
      sample_count=train_dataset_len // FLAGS.batch_size //
      xm.xrt_world_size(),
  )
  test_loader = xu.SampleGenerator(
      data=(
          torch.ones(dims, dtype=DTYPE,),
          torch.ones(
              FLAGS.batch_size,
              dtype=torch.int64 if not _MSE_LOSS else DTYPE,
          ),
      ),
      sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size(),
  )
  devices = (
      xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
      if FLAGS.num_cores != 0 else []
  )
  """ Non multi-processing """
  # Scale learning rate to num cores
  lr = FLAGS.lr * max(len(devices), 1)
  model = MNIST(FLAGS)
  model_parallel = dp.DataParallel(
      model,
      device_ids=devices,
  )
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)

  # Just some step closure output
  # NOTE(review): this helper appears unused in this function; `device` and
  # `epoch` it references only exist later in the epoch loop — confirm
  # whether it is dead code before relying on it.
  def train_output_fn(outputs, ctx, args, tracker):
    if ctx.step > 0 and args.log_steps and ctx.step % args.log_steps == 0:
      now_time = time.time()
      if hasattr(ctx, 'start_time') and ctx.start_time:
        # Average wall time per step since the last timed step.
        per_step_time = (now_time - ctx.start_time) / (
            ctx.step - ctx.last_step_timed
        )
        steps_per_second = 1 / per_step_time
        print(
            f'[{xm.get_ordinal()}] Round-trip step time: '
            f'{per_step_time} seconds, steps per second: {steps_per_second}'
        )
        if tracker:
          _train_update(
              device=device,
              step=ctx.step,
              loss=outputs[0],
              tracker=tracker,
              epoch=epoch,
              writer=writer,
          )
        print(f'BEGIN Train step {ctx.step}')
        ctx.start_time = time.time()
        ctx.last_step_timed = ctx.step
      else:
        ctx.start_time = time.time()
        ctx.last_step_timed = ctx.step
    ctx.step += 1

  def train_loop_fn(model, loader, device=None, context=None):
    lr_adder = 0.0
    if _MSE_LOSS:
      loss_fn = nn.MSELoss()
    else:
      loss_fn = nn.NLLLoss()
    # Optimizer is cached on the context so it survives across epochs.
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(
            model.parameters(),
            lr=lr + lr_adder,
            momentum=FLAGS.momentum,
        ),
    )
    tracker = xm.RateTracker()
    model.train()

    def train_inner_loop_fn(batch, ctx):
      # Single training step; may be run through autograph conversion below.
      step = ctx.step
      print(f'Step {step}')
      data = batch[0]
      target = batch[1]
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      # barrier=False: no explicit sync here; the step closure handles logs.
      xm.optimizer_step(
          optimizer,
          barrier=False,
      )
      if (
          FLAGS.log_steps != 0 and (
              FLAGS.log_steps == 1 or (step > 0 and step % FLAGS.log_steps == 0)
          )
      ):
        xm.add_step_closure(
            _train_update,
            args=(device, step, loss, tracker, epoch, writer),
        )
      if step == 0:
        xm.master_print(f"End TRAIN step {step}")
      ctx.step += 1
      return [loss]

    step = 0
    # Train
    # NOTE(review): missing f-prefix — this prints the literal '{epoch}'.
    print('Starting new epoch train loop... (epoch={epoch})')
    for step, (data, target) in enumerate(loader):
      if step % FLAGS.step_print_interval == 0:
        xm.master_print(f"Begin TRAIN Step: {step}")
      context.step = step
      if not FLAGS.use_autograph:
        outputs = train_inner_loop_fn((data, target), context)
      else:
        outputs = ptwse.flow.runner.maybe_run_converted(
            train_inner_loop_fn, (data, target), context
        )
    xm.master_print(f"Saving model...")
    _save_checkpoint(FLAGS, device, None, model, is_epoch=True)
    xm.master_print(f"Model saved")

  def test_loop_fn(model, loader, device, context):
    print("***********************")
    print("ENTERING TEST FUNCTION")
    print("***********************")
    print('Evaluating...')
    total_samples = 0
    correct = 0
    model.eval()
    for step, (data, target) in enumerate(loader):
      if step >= FLAGS.test_max_step:
        break
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      # In multi-processing mode keep `correct` a tensor so the .item()
      # below happens once, off the hot loop.
      if FLAGS.mp:
        correct += pred.eq(target.view_as(pred)).sum()
      else:
        correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    if FLAGS.mp:
      this_accuracy = 100.0 * correct.item() / total_samples
      print("CALLING: mesh_reduce('test_accuracy')")
      # Average the accuracy across all participating ordinals.
      this_accuracy = xm.mesh_reduce(
          'test_accuracy', this_accuracy, np.mean
      )
      print("BACK FROM: mesh_reduce('test_accuracy')")
    else:
      this_accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, this_accuracy)
    print("***********************")
    print("LEAVING TEST FUNCTION")
    print("***********************")
    return this_accuracy

  #
  # Set up for
  #
  accuracy = 0.0
  num_devices = (
      len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  )
  if not FLAGS.steps_per_epoch:
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * num_devices
    )
  else:
    num_training_steps_per_epoch = FLAGS.steps_per_epoch
  max_accuracy = 0.0
  #
  # Epoch loop
  #
  for epoch in range(1, FLAGS.num_epochs + 1):
    #
    # Train
    #
    device = xm.xla_device()
    ctx = dp.Context(device=device)
    ctx.tracker = xm.RateTracker()
    ctx.step = 0
    train_loop_fn(model, train_loader, device, ctx)
    #
    # Test
    #
    if FLAGS.run_test:
      with ptwse.scope.proxy_disabled(disabled=FLAGS.test_off_proxy):
        accuracies = model_parallel(test_loop_fn, test_loader)
      accuracy = mean(accuracies)
      print(
          'Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy)
      )
    global_step = (epoch - 1) * num_training_steps_per_epoch
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        global_step,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True,
    )
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())
  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy
def train_imagenet():
  """Train an ImageNet model across XLA devices via ``dp.DataParallel``.

  Uses synthetic data when ``FLAGS.fake_data`` is set, otherwise
  ``ImageFolder`` datasets under ``FLAGS.datadir`` with distributed
  samplers when running multi-replica. Returns the last epoch's mean
  test accuracy.
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    # NOTE(review): sample_count divides by FLAGS.batch_size although the
    # generated batches use test_set_batch_size — confirm this is intended.
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (xm.get_xla_supported_devices(
      max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer/scheduler are cached on the per-device context so they
    # persist across epochs.
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    # `num_training_steps_per_epoch` and `writer` are closure variables
    # assigned below, before this function is first invoked.
    lr_scheduler = context.getattr_or(
        'lr_scheduler',
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
            scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
            scheduler_divide_every_n_epochs=getattr(
                FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=writer if xm.is_master_ordinal() else None))
    tracker = xm.RateTracker()
    model.train()
    # The data-parallel loader yields (step_index, (data, target)).
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())
    if lr_scheduler:
      lr_scheduler.step()

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
  num_devices = len(
      xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * num_devices)
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    global_step = (epoch - 1) * num_training_steps_per_epoch
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                     global_step)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
  return accuracy
def train_mnist():
  """Train the MNIST model across XLA devices via ``dp.DataParallel``.

  Uses synthetic data when ``FLAGS.fake_data`` is set, otherwise the
  torchvision MNIST dataset under ``FLAGS.datadir``. Returns the last
  epoch's mean test accuracy.
  """
  torch.manual_seed(1)
  if FLAGS.fake_data:
    train_dataset_len = 60000  # Number of images in MNIST dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    train_dataset = datasets.MNIST(
        FLAGS.datadir,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]))
    train_dataset_len = len(train_dataset)
    test_dataset = datasets.MNIST(
        FLAGS.datadir,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]))
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        sampler=test_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False,
        num_workers=FLAGS.num_workers)
  devices = (xm.get_xla_supported_devices(
      max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Scale learning rate to num cores
  lr = FLAGS.lr * max(len(devices), 1)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(MNIST, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    # Optimizer is cached on the per-device context across epochs.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(), lr=lr, momentum=FLAGS.momentum))
    tracker = xm.RateTracker()
    model.train()
    # NOTE(review): sibling loops in this file iterate the dp loader
    # directly as `for x, (data, target) in loader:`; the extra
    # enumerate() here assumes the loader yields bare (data, target)
    # pairs — confirm which loader contract applies.
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  writer = test_utils.get_summary_writer(FLAGS.logdir)
  num_devices = len(
      xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * num_devices)
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    global_step = (epoch - 1) * num_training_steps_per_epoch
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                     global_step)
    if FLAGS.metrics_debug:
      print(met.metrics_report())
  test_utils.close_summary_writer(writer)
  return accuracy
def train_cifar():
  """Train ResNet18 on CIFAR-10 across XLA devices via ``dp.DataParallel``.

  Uses synthetic data when ``FLAGS.fake_data`` is set, otherwise the
  torchvision CIFAR10 dataset under ``FLAGS.datadir``. Returns the last
  epoch's mean test accuracy.
  """
  print('==> Preparing data..')
  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (xm.get_xla_supported_devices(
      max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
  model_parallel = dp.DataParallel(model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer is cached on the per-device context across epochs.
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()
    model.train()
    # The data-parallel loader yields (step_index, (data, target)).
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    # One accuracy value per device; report the mean.
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
    if FLAGS.metrics_debug:
      print(met.metrics_report())
  return accuracy
def train_imagenet():
  """Train an ImageNet model across XLA devices, tracking top-1/top-5 metrics.

  Variant of the other ``train_imagenet`` in this file: also accumulates
  train-time loss/top-1/top-5, logs everything to TensorBoard, and
  checkpoints the first replica whenever validation top-1 improves.
  Returns the maximum validation top-1 accuracy observed.
  """
  print("==> Preparing data..")
  img_dim = get_model_property("img_dim")
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(
            torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
            torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        ),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size(),
    )
    # NOTE(review): sample_count divides by FLAGS.batch_size although the
    # generated batches use test_set_batch_size — confirm this is intended.
    test_loader = xu.SampleGenerator(
        data=(
            torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
            torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
        ),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
    )
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, "train"),
        transforms.Compose(
            [
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, "val"),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose(
            [
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True,
      )
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False,
      )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=True,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        drop_last=False,
        shuffle=False,
        num_workers=FLAGS.num_workers,
    )

  torch.manual_seed(42)

  devices = (
      xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
      if FLAGS.num_cores != 0
      else ["cpu"]
  )
  print("use tpu devices", devices)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property("model_fn")
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer/scheduler are cached on the per-device context across epochs.
    optimizer = context.getattr_or(
        "optimizer",
        lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=1e-4,
        ),
    )
    lr_scheduler = context.getattr_or(
        "lr_scheduler",
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
            scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
            scheduler_divide_every_n_epochs=getattr(
                FLAGS, "lr_scheduler_divide_every_n_epochs", None
            ),
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=writer if xm.is_master_ordinal() else None,
        ),
    )
    tracker = xm.RateTracker()
    model.train()
    total_samples = 0
    correct = 0
    top5_accuracys = 0
    losses = 0
    # The data-parallel loader yields (step_index, (data, target)).
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      top5_accuracys += topk_accuracy(output, target, topk=5)
      loss = loss_fn(output, target)
      loss.backward()
      losses += loss.item()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print(
            "[{}]({}) Loss={:.5f} Top-1 ACC = {:.2f} Rate={:.2f} GlobalRate={:.2f} Time={}".format(
                str(device),
                x,
                loss.item(),
                (100.0 * correct / total_samples).item(),
                tracker.rate(),
                tracker.global_rate(),
                time.asctime(),
            )
        )
    if lr_scheduler:
      lr_scheduler.step()
    # Averages over the epoch; `x` is the last step index from the loop.
    return (
        losses / (x + 1),
        (100.0 * correct / total_samples).item(),
        (top5_accuracys / (x + 1)).item(),
    )

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    top5_accuracys = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
      top5_accuracys += topk_accuracy(output, target, topk=5).item()
    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy, top5_accuracys

  accuracy = 0.0
  writer = SummaryWriter(FLAGS.logdir) if FLAGS.logdir else None
  num_devices = (
      len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  )
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * num_devices
  )
  max_accuracy = 0.0
  print("train_loader_len", len(train_loader), num_training_steps_per_epoch)
  print("test_loader_len", len(test_loader))
  for epoch in range(1, FLAGS.num_epochs + 1):
    global_step = (epoch - 1) * num_training_steps_per_epoch
    # Train evaluate
    metrics = model_parallel(train_loop_fn, train_loader)
    # One (loss, top1, top5) tuple per device.
    losses, accuracies, top5_accuracys = zip(*metrics)
    loss = mean(losses)
    accuracy = mean(accuracies)
    top5_accuracy = mean(top5_accuracys)
    print(
        "Epoch: {} (Train), Loss {}, Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".format(
            epoch, loss, accuracy, top5_accuracy
        )
    )
    test_utils.add_scalar_to_summary(writer, "Loss/train", loss, global_step)
    test_utils.add_scalar_to_summary(
        writer, "Top-1 Accuracy/train", accuracy, global_step
    )
    test_utils.add_scalar_to_summary(
        writer, "Top-5 Accuracy/train", top5_accuracy, global_step
    )
    # Test evaluate
    metrics = model_parallel(test_loop_fn, test_loader)
    accuracies, top5_accuracys = zip(*metrics)
    top5_accuracys = sum(top5_accuracys)
    top5_accuracy = top5_accuracys / len(test_loader)
    accuracy = mean(accuracies)
    print(
        "Epoch: {} (Valid), Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".format(
            epoch, accuracy, top5_accuracy
        )
    )
    test_utils.add_scalar_to_summary(
        writer, "Top-1 Accuracy/test", accuracy, global_step
    )
    test_utils.add_scalar_to_summary(
        writer, "Top-5 Accuracy/test", top5_accuracy, global_step
    )
    if FLAGS.metrics_debug:
      print(met.metrics_report())
    if accuracy > max_accuracy:
      max_accuracy = max(accuracy, max_accuracy)
      # Checkpoint the first replica on CPU, then move it back.
      torch.save(
          model_parallel.models[0].to("cpu").state_dict(),
          f"./reports/resnet50_model-{epoch}.pt",
      )
      model_parallel.models[0].to(devices[0])
  test_utils.close_summary_writer(writer)
  # NOTE(review): prints the final `accuracy`, not `max_accuracy` — confirm
  # which value was meant here.
  print("Max Accuracy: {:.2f}%".format(accuracy))
  return max_accuracy
settings.STD, ) Flag = {} Flag['lr'] = args.lr Flag['epoch'] = args.e Flag['warm'] = args.warm Flag['batch_size'] = args.b Flag['milestones'] = settings.MILESTONES Flag['ignore_idx'] = train_data_loader.dataset.ignore_index len(train_data_loader.dataset) net = UNet(3, train_data_loader.dataset.class_num) devices = (xm.get_xla_supported_devices()) print(devices) net = dp.DataParallel(net, device_ids=devices) iter_per_epoch = len(train_data_loader) / 8 best_iou = 0 for epoch in range(1, args.e + 1): print('training epoch {}'.format(epoch)) t1 = time.time() net(train_loop_fn, train_data_loader) print(time.time() - t1) #result = net(test_loop_fn, test_data_loader) #pred_res = np.array([res[0] for res in result]) #mask_res = np.array([res[1] for res in result])
def train_unet():
  """Train a UNet on the satellite dataset across XLA devices.

  Configuration (``datadir``, ``batch_size``, ``lr``, ``num_cores``,
  ``num_epochs``, scheduler settings, …) comes from module-level globals
  not visible in this block — presumably set by an argument parser;
  confirm against the module preamble. Returns the last epoch's mean
  test accuracy.
  """
  print('==> Preparing data..')
  img_dim = 1024
  normalize = transforms.Normalize(
      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  train_dataset = satelliteDataSet(
      os.path.join(datadir, 'train'),
      transforms.Compose([
          transforms.RandomResizedCrop(img_dim),
          transforms.RandomHorizontalFlip(),
          transforms.ToTensor(), normalize
      ]))
  train_dataset_len = len(train_dataset)
  resize_dim = max(img_dim, 256)
  # NOTE(review): the test dataset also reads the 'train' split — confirm
  # whether a 'val'/'test' directory was intended.
  test_dataset = satelliteDataSet(
      os.path.join(datadir, 'train'),
      transforms.Compose([
          transforms.Resize(resize_dim),
          transforms.CenterCrop(img_dim),
          transforms.ToTensor(),
          normalize,
      ]))
  train_sampler = None
  test_sampler = None
  if xm.xrt_world_size() > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=batch_size,
      sampler=train_sampler,
      shuffle=False if train_sampler else True,
      num_workers=num_workers)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=test_set_batch_size,
      sampler=test_sampler,
      shuffle=False,
      num_workers=num_workers)

  torch.manual_seed(42)

  devices = (xm.get_xla_supported_devices(max_devices=num_cores))
  torchvision_model = UNet()
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer/scheduler are cached on the per-device context across epochs.
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(
            model.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4))
    lr_scheduler = context.getattr_or(
        'lr_scheduler',
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=lr_scheduler_type,
            scheduler_divisor=lr_scheduler_divisor,
            scheduler_divide_every_n_epochs=
            lr_scheduler_divide_every_n_epochs,
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=None))
    tracker = xm.RateTracker()
    model.train()
    # The data-parallel loader yields (step_index, (data, target)).
    for x, (data, target) in loader:
      optimizer.zero_grad()
      # Dataset appears to yield HWC images; convert to NCHW for the model.
      data = data.permute(0, 3, 1, 2)
      output = model(data)
      print('passed through model')
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(batch_size)
      if x % log_steps == 0:
        print(
            'device: {}, x: {}, loss: {}, tracker: {}, tracker_global: {} '
            .format(device, x, loss.item(), tracker.rate(),
                    tracker.global_rate()))
    if lr_scheduler:
      lr_scheduler.step()

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    # NOTE(review): no permute here, unlike the train loop — confirm the
    # test batches are already channels-first.
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    #test_utils.print_test_update(device, accuracy)
    print('device: {}, accuracy: {}'.format(device, accuracy))
    return accuracy

  accuracy = 0.0
  num_devices = len(
      xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  num_training_steps_per_epoch = train_dataset_len // (
      batch_size * num_devices)
  for epoch in range(1, num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    global_step = (epoch - 1) * num_training_steps_per_epoch
    print('global step: {}'.format(global_step))
  return accuracy
def train_unet():
    """Train a UNet for segmentation on satellite data with XLA devices.

    Relies on module-level configuration globals (``logdir``, ``datadir``,
    ``batch_size``, ``drop_last``, ``num_workers``, ``num_cores``,
    ``log_steps``, ``num_epochs``, ``metrics_debug``). Logs progress to
    ``<logdir>/train_history.txt``, checkpoints once per epoch, and returns
    the mean test accuracy of the final epoch.
    """
    # logging setup
    logger_name = 'train_logger'
    logger = initializeLogging(
        os.path.join(logdir, 'train_history.txt'), logger_name)

    torch.manual_seed(1)

    img_dim = 1024
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = satelliteDataSet(
        os.path.join(datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset)
    resize_dim = max(img_dim, 256)
    # NOTE(review): evaluation also reads from the 'train' split — confirm
    # whether a held-out split was intended.
    test_dataset = satelliteDataSet(
        os.path.join(datadir, 'train'),
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    # Shard data across processes only when running multi-process.
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        drop_last=drop_last,
        shuffle=False if train_sampler else True,
        num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        sampler=test_sampler,
        drop_last=drop_last,
        shuffle=False,
        num_workers=num_workers)

    # len(train_loader) is already the number of batches per epoch; the
    # original divided by batch_size a second time and under-reported the
    # total iteration count.
    maxItr = num_epochs * len(train_loader) + 1

    devices = (xm.get_xla_supported_devices(max_devices=num_cores)
               if num_cores != 0 else [])
    # Scale learning rate to num cores.
    lr = 0.0001
    lr = lr * max(len(devices), 1)

    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(UNet, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Per-device training loop with gradient accumulation every
        `log_steps` iterations."""
        loss_fn = nn.CrossEntropyLoss()
        # The optimizer is cached in `context` so it persists across epochs.
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.Adam(model.parameters(), lr=lr))
        tracker = xm.RateTracker()
        model.train()
        print('# of iterations: {}'.format(maxItr))
        logger.info('# of iterations: {}'.format(maxItr))
        optimizer.zero_grad()
        # The per-device loader yields (step, (data, target)) pairs directly;
        # the original wrapped it in enumerate() and then had to unscramble
        # the mis-bound names (`data = target[0]`, `target = target[1]`).
        for x, (data, target) in loader:
            data = data.permute(0, 3, 1, 2)  # NHWC -> NCHW
            output = model(data)
            loss = loss_fn(output, target.long())
            loss.backward()
            # Step the optimizer only every log_steps iterations
            # (gradient accumulation), as in the original.
            if x % log_steps == 0:
                xm.optimizer_step(optimizer)
                optimizer.zero_grad()
            tracker.add(batch_size)
            print('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
            if x % log_steps == 0:
                logger.info('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))

    def test_loop_fn(model, loader, device, context):
        """Per-device evaluation loop; returns pixel accuracy in percent."""
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            data = data.permute(0, 3, 1, 2)  # NHWC -> NCHW, as in training
            output = model(data)
            _, preds = torch.max(output, 1)
            preds = preds.float()
            correct += preds.eq(target.view_as(preds)).sum().item()
            # Count every labelled pixel. The original used
            # target.shape[1] ** 2, which assumes square images and ignores
            # the batch dimension, skewing the accuracy denominator.
            total_samples += target.numel()
            print('device: {}, Running Accuracy: {}'.format(device, correct/total_samples))
        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        logger.info('TEST: device: {}, accuracy: {}'.format(device, accuracy))
        return accuracy

    accuracy = 0.0
    writer = test_utils.get_summary_writer(logdir)
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (
        batch_size * num_devices)
    print('total epochs: {}'.format(num_epochs))
    for epoch in range(1, num_epochs + 1):
        print(epoch)
        print(train_loader)
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        logger.info('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        if metrics_debug:
            print(met.metrics_report())
            logger.info(met.metrics_report())
        logger.info('saving checkpoint. epoch: {}'.format(epoch))
        # NOTE(review): this pickles the whole DataParallel wrapper (with XLA
        # device state) rather than a plain state_dict; prefer saving
        # model_parallel.models[0].state_dict() if the checkpoint must be
        # reloadable outside this process — confirm before changing format.
        torch.save(model_parallel,
                   os.path.join(logdir, 'model_parallel_chkpt.pt'))
        logger.info('checkpoint saved. epoch: {}'.format(epoch))
    test_utils.close_summary_writer(writer)
    return accuracy
def main():
    """Entry point: fine-tune a masked-LM BERT model across TPU cores.

    Parses CLI args, builds the tokenizer/model, wraps the model in
    DataParallel over the available TPU devices, and runs one training file
    per epoch, checkpointing after each epoch.
    """
    parser = utils.get_args_parser_with_general_args()
    parser.add_argument('--one_tpu', action='store_true', help="Run on one tpu core for debugging. Makes it easy to use break points")
    parser.add_argument('--tpu_report', action='store_true', help="Print xla metric report")
    args = parser.parse_args()

    utils.init(args)  # set seeds, init logger, prepare output directory

    devices = tpu_xm.get_xla_supported_devices()
    if args.one_tpu:
        devices = [devices[0]]
    n_tpu = len(devices)
    logging.info(f'Found {n_tpu} TPU cores')

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    tokenizer.save_pretrained(args.output_dir)
    args.start_epoch = utils.prepare_last_checkpoint(args.bert_model)
    model = AutoModelWithLMHead.from_pretrained(args.bert_model)  # Only Masked Language Modeling
    logging.info(f"Saving initial checkpoint to: {args.output_dir}")
    #model.save_pretrained(args.output_dir) #TODO: why does this break?
    #xm.save(model.state_dict(), args.output_dir)

    model = tpu_dp.DataParallel(model, device_ids=devices)

    num_data_epochs, num_train_optimization_steps = utils.get_dataset_stats(args, n_tpu)

    def tpu_training_loop(model, loader, device, context):
        """Called by torch_xla_py.data_parallel. Executed on each TPU core
        once per epoch; returns the core's average loss for the epoch."""
        param_optimizer = list(model.named_parameters())
        # No weight decay for biases and LayerNorm weights, per standard
        # BERT fine-tuning practice.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        # One optimizer and scheduler per TPU core. Both objects are saved
        # in `context` to be reused the next epoch.
        optimizer = context.getattr_or(
            'optimizer',
            AdamW(optimizer_grouped_parameters,
                  lr=args.learning_rate,
                  eps=args.adam_epsilon,
                  betas=tuple(args.betas)))

        # Derive warmup info from either a proportion or an absolute count.
        if args.warmup_proportion is not None:
            warmup_steps = int(args.warmup_proportion * num_train_optimization_steps + 0.5)
        elif args.warmup_steps is not None:
            warmup_steps = args.warmup_steps
        else:
            raise Exception('What is the warmup?? Specify either warmup proportion or steps')
        scheduler = context.getattr_or(
            'scheduler',
            WarmupLinearSchedule(optimizer,
                                 warmup_steps=warmup_steps,
                                 t_total=num_train_optimization_steps))

        tr_loss = None
        pbar = None
        if str(pbar_device) == str(device):
            # All threads are in sync. Use progress bar only on one of them.
            pbar = tqdm(total=int(pbar_steps), desc=f"device {device}", dynamic_ncols=True)

        tracker = tpu_xm.RateTracker()

        model.train()
        for step, batch in enumerate(loader):
            input_ids, input_mask, segment_ids, lm_label_ids, _ = batch
            # NOTE(review): positional order follows the legacy BERT API
            # (input_ids, token_type_ids, attention_mask, labels) — confirm
            # it matches the signature of the model class actually loaded.
            outputs = model(input_ids, segment_ids, input_mask, lm_label_ids)
            loss = outputs[0]
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.sum().backward()  # for multiple tensors
            tracker.add(args.train_batch_size)
            # Accumulate the un-scaled loss for epoch-level reporting.
            tr_loss = loss * args.gradient_accumulation_steps if step == 0 else tr_loss + loss * args.gradient_accumulation_steps
            if pbar is not None:
                pbar.update(1)
                # pbar.set_description(desc=f'LR: {scheduler.get_lr()}')
            if (step + 1) % args.gradient_accumulation_steps == 0:
                tpu_xm.optimizer_step(optimizer)
                prev_lr = scheduler.get_last_lr()[0]
                scheduler.step()
                curr_lr = scheduler.get_last_lr()[0]
                if args.track_learning_rate:
                    if pbar is not None:
                        pbar.set_description(f"Prev LR: {prev_lr} Curr LR: {curr_lr}")
                optimizer.zero_grad()

        # Average loss per step. `step` holds the LAST enumerate index, so
        # the original `/ step` divided by (num_steps - 1) and raised
        # ZeroDivisionError on a single-step epoch. `.item()` requires a
        # trip from TPU to CPU, which is very slow — use it only once per
        # epoch.
        return tr_loss.sum().item() / (step + 1)

    for epoch in range(args.start_epoch, args.epochs):
        # Load one training file into memory.
        epoch_dataset = utils.PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory)
        train_sampler = RandomSampler(epoch_dataset)
        train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

        # Both names are read by the tpu_training_loop closure above.
        pbar_device = devices[0]
        pbar_steps = utils.compute_num_steps_in_epoch(
            num_samples=train_sampler.num_samples,
            batch_size=args.train_batch_size,
            grad_accum_steps=1,  # the pbar steps should not take into account grad accumulation steps
            n_tpu=n_tpu)
        logging.info(f'start training, epoch {epoch} on {len(devices)} cores for {pbar_steps} steps')
        start = time.time()
        losses = model(tpu_training_loop, train_dataloader)  # calls `tpu_training_loop` multiple times, once per TPU core
        logging.info(f'Epoch {epoch} took {round(time.time() - start, 2)} seconds. Average loss: {sum(losses)/len(losses)}')
        utils.save_checkpoint(model._models[0], epoch, args.output_dir)

    if args.tpu_report:
        logging.info(torch_xla._XLAC._xla_metrics_report())