def map_fn(self, index, train_dataset, dev_dataset, lr, epochs, batch_size, callbacks):
    """Per-process entry point for training + evaluation over `epochs` epochs.

    Runs on either a TPU core (self.using_tpu, via torch_xla) or a single
    CUDA/CPU device.  `index` is the process/core index supplied by the
    spawner (unused in the body).  `callbacks` receive `train_init` once and
    are fired again via `self.on_epoch_end` after every epoch.
    """
    # Device selection: XLA device on TPU, otherwise CUDA if available.
    if self.using_tpu is True:
        device = xm.xla_device()
    else:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    train_loader = self.make_loader(train_dataset, batch_size, 'train')
    dev_loader = self.make_loader(dev_dataset, batch_size, 'dev')
    model = self.model.to(device)
    if self.using_tpu:
        # Linear LR scaling across TPU cores; weight decay is hard coded.
        opt = self.Opt([param for param in model.parameters() if param.requires_grad], lr=lr*xm.xrt_world_size(), weight_decay=1e-4)  # hard coding
    else:
        opt = self.Opt([param for param in model.parameters() if param.requires_grad], lr=lr, weight_decay=1e-4)  # hard coding
    loss_fn = self.Loss_fn(from_logits=True)
    callback_kwargs = {
        "model": model,
        "eval_dic": self.dev_eval,
    }
    for callback in callbacks:
        callback.train_init(**callback_kwargs)
    for epoch in range(epochs):
        if self.using_tpu:
            # Rendezvous barriers keep the TPU cores in lock-step; only the
            # master ordinal prints progress.
            xm.rendezvous("training is starting!")
            if xm.is_master_ordinal():
                print(f"\nepoch : {epoch+1} / {epochs}")
            now_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        else:
            print(f"epoch : {epoch+1} / {epochs}")
            now_train_loader = train_loader
        model.train()
        for step, batch in enumerate(now_train_loader):
            logits, y, loss = self.compute_batch(model, batch, device, loss_fn, opt, phase='train')
            if self.using_tpu:
                xm.rendezvous("update is starting!")
                self.update(logits, y, loss, 'train', batch_size)
                xm.rendezvous("update is ended!")
                if xm.is_master_ordinal():
                    # step is per-core; scale by world size for a global count.
                    self.show_log(step*xm.xrt_world_size(), train_dataset, batch_size, 'train')
            else:
                self.update(logits, y, loss, 'train', batch_size)
                self.show_log(step, train_dataset, batch_size, 'train')
        if self.using_tpu:
            xm.rendezvous("batch is done!")
            if xm.is_master_ordinal():
                print()
        else:
            print()
        # Evaluation pass over the dev set (no gradients).
        model.eval()
        with torch.no_grad():
            if self.using_tpu:
                now_dev_loader = pl.ParallelLoader(dev_loader, [device]).per_device_loader(device)
            else:
                now_dev_loader = dev_loader
            for step, batch in enumerate(now_dev_loader):
                # opt is passed but phase='dev' — presumably compute_batch skips
                # the optimizer step in that phase; confirm against compute_batch.
                logits, y, loss = self.compute_batch(model, batch, device, loss_fn, opt, phase='dev')
                if self.using_tpu:
                    xm.rendezvous("update is starting!")
                    self.update(logits, y, loss, 'dev', batch_size)
                    xm.rendezvous("eval update is ended!")
                    if xm.is_master_ordinal():
                        self.show_log(step*xm.xrt_world_size(), dev_dataset, batch_size, 'dev')
                else:
                    self.update(logits, y, loss, 'dev', batch_size)
                    self.show_log(step, dev_dataset, batch_size, 'dev')
            if self.using_tpu:
                xm.rendezvous("batch is done!")
                if xm.is_master_ordinal():
                    print()
            else:
                print()
        self.on_epoch_end(callbacks)
        if self.using_tpu:
            xm.rendezvous("training is over!")
def train(rank, args):
    """Distributed training worker for one GPU rank or TPU ordinal.

    Builds tokenizer/dataset/loaders, creates the model and optimizer
    (apex FusedAdam + amp + DDP on GPU, plain Adam on TPU), optionally
    restores a checkpoint, runs the update loop with a linear-warmup
    schedule, then evaluates every requested split. The final state dict
    is always saved in the ``finally`` block, and any captured exception
    is re-raised afterwards.

    Fixes vs. previous revision:
    - eval DistributedSamplers now sample the *eval* dataset (was train_dataset);
    - warmup_updates is scaled by batch/world size itself (was overwritten
      with the scaled total_num_updates, destroying the warmup ratio);
    - ``torch.car`` typo -> ``torch.cat`` in the all_gather reduction;
    - a captured exception is re-raised on every rank, not only rank 0.
    """
    print('enter train @ %s' % (rank), flush=True)
    args.rank = rank
    args.split = ''
    torch.manual_seed(42)
    save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size() if not args.vocab_size else args.vocab_size
    train_dataset = get_dataset(args)
    # Datasets exposing __getbatch__ already yield full batches; the
    # DataLoader must then not batch again (batch_size=None below).
    batched_already = hasattr(train_dataset, '__getbatch__')

    # total_num_updates < 100 is interpreted as "epochs"; warmup < 1 as a ratio.
    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates
    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)

    train_sampler = None
    if args.gpus:
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=args.gpus, rank=rank, shuffle=args.shuffle)
    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=args.shuffle)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size if not batched_already else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)

    eval_loaders = []
    if args.eval_dir:
        for split in args.splits.split(','):
            split = split.strip()
            # Build the dataset first so the sampler shards the *eval* data
            # (previously the sampler was mistakenly built over train_dataset).
            args.split = split
            eval_dataset = get_eval_dataset(args)
            eval_sampler = None
            if args.gpus:
                if args.gpus > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset, num_replicas=args.gpus, rank=rank, shuffle=False)
            else:
                rank = xm.get_ordinal()
                if xm.xrt_world_size() > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=False)
            eval_loader = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=args.batch_size if not batched_already else None,
                sampler=eval_sampler,
                pin_memory=True,
                shuffle=False,
                num_workers=args.num_workers)
            eval_loaders.append(eval_loader)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args, tokenizer)
        model.cuda(rank)
        device = torch.device('cuda:' + str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################
        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay,
                                 special_layer_wise_lr=args.special_layer_wise_lr,
                                 log=rank == 0,
                                 ),
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay
        )
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader
    else:
        assert tpu_enabled
        device = xm.xla_device()

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args, tokenizer)

        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device)
        ##  So we first find the shared parameters first
        ##
        ##########################
        shared_parameters = {e[0]: e[1:] for e in _catalog_shared_params(model)}
        model.to(device)
        do_share_parameters_again(model, shared_parameters, log=rank == 0)

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################
        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay
                                 ),
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay
        )
        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)
        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        # Normalize checkpoint keys: strip DDP's "module." prefix and remap
        # old "position_ids" buffers onto "position_embeddings".
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        if args.gpus:
            # DDP expects the "module." prefix back.
            states = {"module.%s" % k: v for k, v in states.items()}
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            if rank == 0:
                traceback.print_exc()
            # Fall back to partial load when shapes/keys mismatch.
            model.load_state_dict(states, strict=False)

    if rank == 0:
        if not os.path.exists(os.path.dirname(save_fn)):
            try:
                os.makedirs(os.path.dirname(save_fn))
            except OSError as exc:  # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise
        if args.gpus:
            torch.save(model.state_dict(), save_fn)
        else:
            xm.save(model.state_dict(), save_fn)

    model.train()
    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################
    # Convert sample counts to optimizer-step counts.  The warmup count is
    # scaled by the same divisors as the total (it was previously overwritten
    # with the scaled total, which silenced the warmup schedule entirely).
    if not batched_already:
        args.total_num_updates = args.total_num_updates // args.batch_size
        args.warmup_updates = args.warmup_updates // args.batch_size
    args.total_num_updates = args.total_num_updates // args.world_size
    args.warmup_updates = args.warmup_updates // args.world_size
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )

    step_i = 0
    err = None
    tb = None
    # tb = SummaryWriter()
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates, file=sys.stdout)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break
                report_step = step_i % args.log_interval == 0
                while True:  # the loop only for apex Gradient Overflow
                    optimizer.zero_grad()
                    total_loss, log = get_loss(
                        model, sample, args=args, device=device,
                        gpus=args.gpus, report=report_step
                    )
                    if args.gpus:
                        default_optimizer_step = optimizer.step
                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        # If Amp detects an overflow, it patches optimizer.step. In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break
                        optimizer.step()
                        # If an overflow was detected, "optimizer.step" is the patched call, which does
                        # nothing but restore optimizer.step to default_optimizer_step.
                        if rank == 0:
                            print("Overflowed, reducing loss scale and replaying batch.", flush=True)
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()
                        break
                scheduler.step()
                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss
                    # tb.add_scalar("Loss", total_loss, step_i)
                    for k, v in log.items():
                        try:
                            dist.all_reduce(v, op=dist.reduce_op.SUM)
                            log[k] = float(v)
                        except Exception as e:
                            print(v, e)
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter, tb, step_i))
                    else:
                        xm.add_step_closure(_train_update, args=(log, log_formatter, tb, step_i))
                        if args.report_metrics:
                            xm.master_print(met.metrics_report())
                if rank == 0:
                    pbar.update(1)
        if rank == 0:
            pbar.close()

        if eval_loaders:
            # NOTE(review): .half()/.cuda() is unconditional here even on the
            # TPU path — confirm eval is only expected to run with args.gpus.
            model.half()
            model.eval()
            model.cuda()
            for k, v in model.named_parameters():
                v.requires_grad = False
            for split, eval_loader in zip(args.splits.split(','), eval_loaders):
                batches = eval_loader
                if rank == 0:
                    eval_length = len(batches)
                    if not batched_already:
                        eval_length = eval_length // args.batch_size
                    eval_length = eval_length // args.world_size
                    pbar = tqdm(total=eval_length, file=sys.stdout)
                if not args.gpus:
                    batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device)
                with torch.no_grad():
                    record = OrderedDict()
                    for sample in batches:
                        evaluate(
                            model, sample, args=args, device=device,
                            record=record, gpus=args.gpus, report=False
                        )
                        if rank == 0:
                            pbar.update(1)
                    for k, v in record.items():
                        try:
                            def handle_reduce(v):
                                # Scalars are summed; tensors are gathered and
                                # concatenated across ranks.
                                if len(v.shape) == 0:
                                    dist.all_reduce(v, op=dist.reduce_op.SUM)
                                else:
                                    L = [torch.ones_like(v) for _ in range(dist.get_world_size())]
                                    dist.all_gather(L, v)
                                    v = torch.cat(L, dim=0)  # was torch.car — typo
                                return v
                            if isinstance(v, list):
                                v = [handle_reduce(e) for e in v]
                            else:
                                v = handle_reduce(v)
                            record[k] = float(v)
                        except Exception as e:
                            pass
                    post_evaluate(record, args=args)
                    import json
                    if rank == 0:
                        print('', flush=True)
                        print('Test result for %s' % split, flush=True)
                        print(json.dumps(record, indent=2), flush=True)
                        print('', flush=True)
    except Exception as _err:
        err = _err
    finally:
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0:
            print("Saving to %s" % save_fn)
            if args.gpus:
                torch.save(model.state_dict(), save_fn)
            else:
                xm.save(model.state_dict(), save_fn)
            print("Saved to %s" % save_fn)
        # Re-raise on every rank so failures are never silently swallowed
        # on non-zero ranks.
        if err:
            raise err
def train_mnist(flags, **kwargs):
    """Train and evaluate MNIST on the local XLA device.

    Uses fake data generators when ``flags.fake_data`` is set; otherwise
    downloads MNIST per-ordinal, shards it with a DistributedSampler when
    running on more than one core, and reports per-epoch test accuracy.
    Returns the best accuracy observed across all epochs.
    """
    torch.manual_seed(1)

    world = xm.xrt_world_size()
    if flags.fake_data:
        # Synthetic zero tensors shaped like MNIST batches.
        fake_batch = (
            torch.zeros(flags.batch_size, 1, 28, 28),
            torch.zeros(flags.batch_size, dtype=torch.int64),
        )
        train_loader = xu.SampleGenerator(
            data=fake_batch,
            sample_count=60000 // flags.batch_size // world,
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=10000 // flags.batch_size // world,
        )
    else:
        mnist_tfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        data_root = os.path.join(flags.datadir, str(xm.get_ordinal()))
        train_dataset = datasets.MNIST(
            data_root, train=True, download=True, transform=mnist_tfm)
        test_dataset = datasets.MNIST(
            data_root, train=False, download=True, transform=mnist_tfm)

        train_sampler = None
        if world > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=world,
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers,
        )

    # Scale learning rate to num cores
    lr = flags.lr * world

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = test_utils.get_summary_writer(flags.logdir) if xm.is_master_ordinal() else None
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        """One training epoch over a per-device loader."""
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, writer))

    def test_loop_fn(loader):
        """Compute mesh-averaged test accuracy (percent)."""
        seen, hits = 0, 0
        model.eval()
        for data, target in loader:
            pred = model(data).max(1, keepdim=True)[1]
            hits += pred.eq(target.view_as(pred)).sum()
            seen += data.size()[0]
        acc = 100.0 * hits.item() / seen
        return xm.mesh_reduce("test_accuracy", acc, np.mean)

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print("Epoch {} train end {}".format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_device_loader)
        xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer, epoch,
                                    dict_to_write={"Accuracy/test": accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy
def train_imagenet():
    """Train an ImageNet model on the local XLA device and return final test accuracy.

    With ``FLAGS.fake_data`` the loaders are synthetic zero-tensor generators;
    otherwise standard torchvision ImageFolder pipelines are used, sharded via
    DistributedSampler when more than one core is present.

    Fix vs. previous revision: the fake test loader's sample_count was divided
    by FLAGS.batch_size although its batches are FLAGS.test_set_batch_size,
    over/under-counting the synthetic test set.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            # Was FLAGS.batch_size: count must match the test batch size above.
            sample_count=50000 // FLAGS.test_set_batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if FLAGS.logdir and xm.is_master_ordinal():
        writer = SummaryWriter(log_dir=FLAGS.logdir)
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        # One epoch over a per-device loader; loader yields (step, batch).
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        # Accuracy (percent) of argmax predictions over the test loader.
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        if xm.is_master_ordinal():
            print("Finished training epoch {}".format(epoch))
        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    return accuracy
def is_master_ordinal(self):
    """Return True iff this process is the master XLA ordinal (delegates to xm)."""
    return xm.is_master_ordinal()
def train_imagenet():
    """Train a timm-created model on ImageNet using the local XLA device.

    Optionally maintains an EMA copy of the model (``FLAGS.model_ema``) that
    is evaluated separately each epoch. Returns the last measured accuracy.

    Fixes vs. previous revision:
    - ``flags`` (undefined lowercase name) -> ``FLAGS`` in create_optimizer,
      create_scheduler and LabelSmoothingCrossEntropy (was a NameError);
    - create_loader's ``interpolation`` left a bare positional argument after
      keyword arguments (SyntaxError); resolved to data_config's value per
      the inline FIXME;
    - the test ParallelLoader was re-used, already exhausted, for the EMA
      evaluation, producing an empty loop; it is now recreated.
    """
    torch.manual_seed(42)
    device = xm.xla_device()

    # model = get_model_property('model_fn')().to(device)
    model = create_model(
        FLAGS.model,
        pretrained=FLAGS.pretrained,
        num_classes=FLAGS.num_classes,
        drop_rate=FLAGS.drop,
        global_pool=FLAGS.gp,
        bn_tf=FLAGS.bn_tf,
        bn_momentum=FLAGS.bn_momentum,
        bn_eps=FLAGS.bn_eps,
        drop_connect_rate=0.2,
        checkpoint_path=FLAGS.initial_checkpoint,
        args=FLAGS).to(device)

    model_ema = None
    if FLAGS.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_e = create_model(
            FLAGS.model,
            pretrained=FLAGS.pretrained,
            num_classes=FLAGS.num_classes,
            drop_rate=FLAGS.drop,
            global_pool=FLAGS.gp,
            bn_tf=FLAGS.bn_tf,
            bn_momentum=FLAGS.bn_momentum,
            bn_eps=FLAGS.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=FLAGS.initial_checkpoint,
            args=FLAGS).to(device)
        model_ema = ModelEma(
            model_e,
            decay=FLAGS.model_ema_decay,
            device='cpu' if FLAGS.model_ema_force_cpu else '',
            resume=FLAGS.resume)

    print('==> Preparing data..')
    img_dim = 224
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dir = os.path.join(FLAGS.data, 'train')
        data_config = resolve_data_config(model, FLAGS, verbose=FLAGS.local_rank == 0)
        dataset_train = Dataset(train_dir)
        collate_fn = None
        if not FLAGS.no_prefetcher and FLAGS.mixup > 0:
            collate_fn = FastCollateMixup(FLAGS.mixup, FLAGS.smoothing,
                                          FLAGS.num_classes)
        train_loader = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=True,
            use_prefetcher=not FLAGS.no_prefetcher,
            rand_erase_prob=FLAGS.reprob,
            rand_erase_mode=FLAGS.remode,
            # Resolved per the old FIXME: use the interpolation resolved by
            # the model's data config instead of hard-coding 'bicubic'.
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
            collate_fn=collate_fn,
            use_auto_aug=FLAGS.auto_augment,
            use_mixcut=FLAGS.mixcut,
        )
        eval_dir = os.path.join(FLAGS.data, 'val')
        train_dataset_len = len(train_loader)
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
        dataset_eval = Dataset(eval_dir)
        test_loader = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=False,
            # NOTE(review): train uses `not FLAGS.no_prefetcher` while eval
            # uses `FLAGS.prefetcher` — confirm both flags exist and agree.
            use_prefetcher=FLAGS.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
        )

    writer = None
    start_epoch = 0
    if FLAGS.output and xm.is_master_ordinal():
        writer = SummaryWriter(log_dir=FLAGS.output)
    optimizer = create_optimizer(FLAGS, model)  # was undefined `flags`
    lr_scheduler, num_epochs = create_scheduler(FLAGS, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    # NOTE(review): the timm scheduler above is immediately replaced by the
    # xla wrapper below (previous behavior, preserved) — confirm intended.
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=FLAGS.smoothing)
    validate_loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        # One epoch; loader yields (step, (data, target)).
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = train_loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if model_ema is not None:
                model_ema.update(model)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        # Top-1 accuracy (percent) of the live model.
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    def test_loop_fn_ema(loader):
        # Top-1 accuracy (percent) of the EMA model.
        total_samples = 0
        correct = 0
        model_ema.eval()
        for x, (data, target) in loader:
            output = model_ema(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.epochs + 1):
        para_loader = dp.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        para_loader = dp.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        if model_ema is not None:
            # Recreate the loader: the previous one is already exhausted,
            # so reusing it evaluated the EMA model over zero batches.
            para_loader = dp.ParallelLoader(test_loader, [device])
            accuracy = test_loop_fn_ema(para_loader.per_device_loader(device))
            print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy
def train(args, train_loader, model, device, optimizer, scheduler, epoch, f,
          max_seq_len):
    """Run one training epoch for span prediction (start/end logits).

    Tracks running loss/accuracy meters, supports gradient accumulation,
    mesh-reduces the displayed metrics across TPU ordinals, appends an
    epoch summary to the log file ``f``, and returns the mean total loss.
    """
    loss_meter = AverageMeter()
    start_loss_meter = AverageMeter()   # start position
    end_loss_meter = AverageMeter()     # end position
    start_acc_meter = AverageMeter()    # start position
    end_acc_meter = AverageMeter()      # end position

    model.train()
    tr_loss = 0.0
    bar = tqdm(train_loader, disable=not xm.is_master_ordinal())
    for step, d in enumerate(bar):
        # Move the whole batch onto the XLA device.
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        token_type_ids = d["token_type_ids"].to(device)
        start_position = d["start_position"].to(device)
        end_position = d["end_position"].to(device)
        sentiment_label = d["sentiment_label"].to(device)

        model.zero_grad()
        logits1, logits2 = model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=None,
                                 head_mask=None)
        #y_true = (start_position, end_position)
        loss1, loss2 = loss_fn((logits1, logits2),
                               (start_position, end_position))
        #loss3 = loss_fn_sentiment(logits3, sentiment_label)
        loss = loss1 + loss2
        #max_seq_len = 256
        #loss = Closs.loss_fn(logits1, logits2, start_position, end_position,device, max_seq_len)

        acc1, n_position1 = get_position_accuracy(logits1, start_position)
        acc2, n_position2 = get_position_accuracy(logits2, end_position)
        loss_meter.update(loss.item(), n_position1)
        start_loss_meter.update(loss1.item(), n_position1)
        end_loss_meter.update(loss2.item(), n_position2)
        start_acc_meter.update(acc1, n_position1)
        end_acc_meter.update(acc2, n_position2)

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
        loss.backward()
        tr_loss += loss.item()

        # Step only on accumulation boundaries.
        if (step + 1) % args.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            xm.optimizer_step(optimizer)
            scheduler.step()
            model.zero_grad()

        # Cross-ordinal averages for display only.
        print_loss = xm.mesh_reduce("loss_reduce", loss_meter.avg, reduce_fn)
        print_acc1 = xm.mesh_reduce("acc1_reduce", start_acc_meter.avg, reduce_fn)
        print_acc2 = xm.mesh_reduce("acc2_reduce", end_acc_meter.avg, reduce_fn)
        bar.set_description(
            f"Train E:{epoch+1} - Loss:{print_loss:0.2f} - acc1:{print_acc1:0.2f} - acc2:{print_acc2:0.2f}"
        )

    log_ = f"Epoch : {epoch+1} - train_loss : {loss_meter.avg} - \n \
        train_loss1 : {start_loss_meter.avg} - train_loss2 : {end_loss_meter.avg} - \n \
        train_acc1 : {start_acc_meter.avg} - train_acc2 : {end_acc_meter.avg}"
    f.write(log_ + "\n\n")
    f.flush()
    return loss_meter.avg
def train_mnist():
    """Train MNIST on the local XLA device using module-level FLAGS.

    Supports fake-data generators, multi-core sharding via
    DistributedSampler, and per-epoch test evaluation. Returns the final
    epoch's accuracy.
    """
    torch.manual_seed(1)

    world = xm.xrt_world_size()
    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size // world)
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // world)
    else:
        mnist_tfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        data_root = os.path.join(FLAGS.datadir, str(xm.get_ordinal()))
        train_dataset = datasets.MNIST(data_root,
                                       train=True,
                                       download=True,
                                       transform=mnist_tfm)
        test_dataset = datasets.MNIST(data_root,
                                      train=False,
                                      download=True,
                                      transform=mnist_tfm)
        train_sampler = None
        if world > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=world,
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    # Scale learning rate to num cores
    lr = FLAGS.lr * world

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = test_utils.get_summary_writer(FLAGS.logdir) if xm.is_master_ordinal() else None
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        """One training epoch over a per-device loader."""
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker))

    def test_loop_fn(loader):
        """Top-1 accuracy (percent) over the test loader."""
        seen, hits = 0, 0
        model.eval()
        for data, target in loader:
            pred = model(data).max(1, keepdim=True)[1]
            hits += pred.eq(target.view_as(pred)).sum().item()
            seen += data.size()[0]
        acc = 100.0 * hits / seen
        test_utils.print_test_update(device, acc)
        return acc

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print('Finished training epoch {}'.format(epoch))
        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
def multi_core(index, flags):
    """Per-process training loop for NeuroImageModel on one XLA device.

    Args:
        index: process index assigned by xmp.spawn.
        flags: dict with at least 'seed', 'batch_size', 'num_workers',
            'num_epochs'.

    Side effects: writes per-epoch average training losses to a local file
    named 'loss' and saves a checkpoint after training.
    """
    torch.manual_seed(flags['seed'])
    device = xm.xla_device()

    # Only download X, Y on one process: non-masters block here until the
    # master has materialized the datasets and joins the rendezvous below.
    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')
    train_dataset = MRIDataset(mode='train')
    valid_dataset = MRIDataset(mode='validation')
    val_loss = []
    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    # XLA distributed sampler for more than 1 TPU
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=flags['num_workers'],
        drop_last=True)
    # Validation loader is built but the validation pass is currently
    # disabled; kept so it can be re-enabled without reworking setup.
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=flags['batch_size'],
        sampler=valid_sampler,
        num_workers=flags['num_workers'],
        drop_last=True)

    model = NeuroImageModel().to(device).train()
    criterion = torch.nn.L1Loss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.hyper_params['lr'],
                                momentum=0.9,
                                nesterov=True)

    train_start = time.time()
    # Fix: the loss file was opened and never closed; use a context manager.
    with open('loss', 'w') as loss_file:
        for epoch in range(flags['num_epochs']):
            # Training time
            train_pl_loader = pl.ParallelLoader(
                train_loader, [device]).per_device_loader(device)
            average = 0
            count = 0
            for batch_num, batch in enumerate(train_pl_loader):
                optimizer.zero_grad()
                print("Process", index, "saving scan")
                scans = batch['scans']
                data = batch['data']
                targets = batch['targets']
                output = model(data, scans)
                del scans
                loss = criterion(output, targets)
                del targets
                loss.backward()
                xm.master_print(
                    f'training: index: {batch_num} loss: {loss.item()}')
                count = count + 1
                average = average + loss.item()
                # NOTE(review): barrier=True forces a mark_step here; with a
                # ParallelLoader this is normally redundant — confirm before
                # removing.
                xm.optimizer_step(optimizer, barrier=True)
            print(
                f'Training loss for epoch: {epoch} average of: {average/count} with count {count}'
            )
            loss_file.write(
                f'Training loss for epoch: {epoch} average of: {average/count} with count {count}'
            )
            del loss

    elapsed_train_time = time.time() - train_start
    print("Process", index, "finished training. Train time was:",
          elapsed_train_time)

    # Fix: previously an f-string *describing* the state was saved, which can
    # never be loaded back as a checkpoint. Save a real checkpoint dict whose
    # entries round-trip through torch.load.
    torch.save(
        {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'validation_loss': val_loss,
            'optimizer': optimizer.state_dict(),
        },
        f'{config.hyper_params["model_save_path"]}/validation_loss_{time.time()}.txt'
    )
def run(config):
    """Build and train a GAN (G/D pair, optional EMA copy) on a TPU device.

    The heavily stateful training loop mutates `state_dict` (iteration
    counter, best scores) and periodically saves/samples/tests via
    `train_fns`. Runs until `state_dict['itr']` reaches
    config['total_steps'].
    """

    # Give PerDeviceLoader a __len__ that reports the underlying loader's
    # length (it does not define one itself).
    def len_parallelloader(self):
        return len(self._loader._loader)
    pl.PerDeviceLoader.__len__ = len_parallelloader

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        xm.master_print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different
    # files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    xm.master_print('Experiment name is %s' % experiment_name)

    device = xm.xla_device(devkind='TPU')

    # Next, build the model
    G = model.Generator(**config)
    D = model.Discriminator(**config)

    # If using EMA, prepare it
    if config['ema']:
        xm.master_print(
            'Preparing EMA for G with decay of {}'.format(
                config['ema_decay']))
        G_ema = model.Generator(**{**config,
                                   'skip_init': True,
                                   'no_optim': True})
    else:
        xm.master_print('Not using ema...')
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        xm.master_print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        xm.master_print('Casting D to fp16...')
        D = D.half()

    # Prepare state dict, which holds things like itr #
    # NOTE(review): this assignment was swallowed by a collapsed comment in
    # the original formatting; it must be code — state_dict is read and
    # mutated throughout the rest of this function.
    state_dict = {'itr': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        xm.master_print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # move everything to TPU
    G = G.to(device)
    D = D.to(device)

    G.optim = optim.Adam(params=G.parameters(), lr=G.lr,
                         betas=(G.B1, G.B2), weight_decay=0,
                         eps=G.adam_eps)
    D.optim = optim.Adam(params=D.parameters(), lr=D.lr,
                         betas=(D.B1, D.B2), weight_decay=0,
                         eps=D.adam_eps)

    # for key, val in G.optim.state.items():
    #   G.optim.state[key]['exp_avg'] = G.optim.state[key]['exp_avg'].to(device)
    #   G.optim.state[key]['exp_avg_sq'] = G.optim.state[key]['exp_avg_sq'].to(device)

    # for key, val in D.optim.state.items():
    #   D.optim.state[key]['exp_avg'] = D.optim.state[key]['exp_avg'].to(device)
    #   D.optim.state[key]['exp_avg_sq'] = D.optim.state[key]['exp_avg_sq'].to(device)

    if config['ema']:
        G_ema = G_ema.to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

    # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    xm.master_print(G)
    xm.master_print(D)
    xm.master_print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    xm.master_print(
        'Test Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    xm.master_print(
        'Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    if xm.is_master_ordinal():
        # Write metadata
        utils.write_metadata(
            config['logs_root'], experiment_name, config, state_dict)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    xm.master_print('Preparing data...')
    loader = utils.get_data_loaders(**{**config,
                                       'batch_size': D_batch_size,
                                       'start_itr': state_dict['itr']})

    # Prepare inception metrics: FID and IS
    xm.master_print('Preparing metrics...')
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'],
        no_inception=config['no_inception'], no_fid=config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    def sample():
        # Fresh (z, y) pair for G on every call.
        return utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                                 device=device, fp16=config['G_fp16'])

    # Prepare a fixed z & y to see individual sample evolution throghout
    # training
    fixed_z, fixed_y = sample()

    train = train_fns.GAN_training_function(
        G, D, GD, sample, ema, state_dict, config)

    xm.master_print('Beginning training...')

    if xm.is_master_ordinal():
        # Master ordinal owns the progress bar; resume it at the current itr.
        pbar = tqdm(total=config['total_steps'])
        pbar.n = state_dict['itr']
        pbar.refresh()

    xm.rendezvous('training_starts')
    while (state_dict['itr'] < config['total_steps']):
        pl_loader = pl.ParallelLoader(
            loader, [device]).per_device_loader(device)

        for i, (x, y) in enumerate(pl_loader):
            if xm.is_master_ordinal():
                # Increment the iteration counter
                pbar.update(1)
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter
            # much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()

            xm.rendezvous('data_collection')
            metrics = train(x, y)
            # train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if ((config['sv_log_interval'] > 0) and
                    (not (state_dict['itr'] % config['sv_log_interval']))):
                if xm.is_master_ordinal():
                    train_log.log(
                        itr=int(state_dict['itr']),
                        **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
                xm.rendezvous('Log SVs.')

            # Save weights and copies as configured at specified interval
            if (not (state_dict['itr'] % config['save_every'])):
                if config['G_eval_mode']:
                    xm.master_print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(
                    G, D, G_ema, sample, fixed_z, fixed_y, state_dict,
                    config, experiment_name)

            # Test every specified interval
            if (not (state_dict['itr'] % config['test_every'])):
                which_G = G_ema if config['ema'] and config['use_ema'] else G
                if config['G_eval_mode']:
                    xm.master_print('Switchin G to eval mode...')
                    which_G.eval()

                def G_sample():
                    # Sample from whichever generator (EMA or live) is active.
                    z, y = sample()
                    return which_G(z, which_G.shared(y))

                train_fns.test(
                    G, D, G_ema, sample, state_dict, config, G_sample,
                    get_inception_metrics, experiment_name, test_log)

            # Debug : Message print
            # if True:
            #   xm.master_print(met.metrics_report())

            if state_dict['itr'] >= config['total_steps']:
                break
def _mp_fn(index, args):
    """Per-process entry point for xmp.spawn: runs main_tpu on one TPU core.

    Args:
        index: process index assigned by the multiprocessing spawner
            (unused here).
        args: parsed command-line arguments forwarded to main_tpu.
    """
    # Keep CPU-side default dtype at float32 to match XLA tensors.
    torch.set_default_tensor_type('torch.FloatTensor')
    # Passes the is-master flag; presumably output is suppressed on
    # non-master ordinals — confirm against distributed_utils.
    distributed_utils.suppress_output(xm.is_master_ordinal())
    main_tpu(args)
def train_model(model,
                criterion,
                optimizer,
                scheduler,
                i,
                class_names,
                metric_targets,
                metric_types,
                dataset_types,
                data_loaders,
                dataset_sizes,
                device,
                cfg,
                num_epochs=25,
                batch_size=4,
                patience=5,
                lambda_u=1.0,
                threshold=0.95,
                purpose='baseline',
                is_early=True):
    '''Train the model.

    Args:
        model (obj): the model which will be trained
        criterion (obj): the loss function (e.g. cross entropy)
        optimizer (obj): the optimizer (e.g. Adam)
        scheduler (obj): the learning scheduler (e.g. Step decay)
        i (int): the number indicating which model it is
        class_names (dict): class names for images
            (e.g. {0: 'covid-19', 1: 'pneumonia', 2: 'normal'})
        metric_targets (list): metric targets to calculate performance metrics
            of the model (e.g. ['all', 'covid-19', 'pneumonia', 'normal'])
        metric_types (list): the performance metrics of the model
            (e.g. Accuracy, F1-Score and so on)
        dataset_types (list): dataset types for train and test
            (e.g. ['train', 'test'] or ['train', 'val', 'test'])
        data_loaders (list): data loaders applied transformations,
            the batch size and so on
        dataset_sizes (dict): sizes of train and test datasets
        device (obj): the device where the model will be trained
            (e.g. cpu or gpu)
        cfg (dict): run configuration ('use_tpu', 'sharpening', 'focal_loss',
            'temperature', 'gamma', 'mu', ...)
        num_epochs (int): the number of epochs
        batch_size (int): the batch size
        patience (int): the number of patience times for early stopping
        lambda_u (float): the ratio of reflect unlabeled loss
        threshold (float): the treshold for predicted results for
            unlabeled data
        purpose (str): the purpose of the model

    Returns:
        model (obj): the model which was trained
        best_metrics (dict): the results of the best performance metrics
            after training the model
    '''
    # Import XLA libraries for using TPUs
    if cfg['use_tpu']:
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as pl

    since = time.time()
    if is_early:
        early_stopping = EarlyStopping(patience=patience, verbose=True)
    best_metrics = {m_type: defaultdict(float) for m_type in metric_types}
    epoch_metrics_list = []

    print(f'{"-" * 20}\nModel {i + 1}\n{"-" * 20}\n')
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and test phase
        for phase in dataset_types:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                # On TPU only the master ordinal runs evaluation.
                if cfg['use_tpu'] and not xm.is_master_ordinal():
                    continue
                model.eval()  # Set model to evaluate mode

            epoch_loss = 0.0
            batch_metrics = {
                'tp': defaultdict(int),
                'size': defaultdict(int),
                'fp': defaultdict(int),
                'fn': defaultdict(int)
            }
            mask_ratio = []  # just for fixmatch

            # Create a parallel loader
            if cfg['use_tpu'] and phase == 'train':
                # data_loaders[phase].sampler.set_epoch(epoch)
                final_data_loader = pl.ParallelLoader(
                    data_loaders[phase], [device]).per_device_loader(device)
            else:
                final_data_loader = data_loaders[phase]

            # Iterate over data.
            for batch in final_data_loader:
                size = batch['img_lb'].size(0)
                # Load batches: for FixMatch-style training, concatenate the
                # labeled batch with the weak/strong augmented unlabeled ones
                # so the model runs a single forward pass.
                if purpose != 'baseline' and phase == 'train':
                    inputs = torch.cat([
                        batch['img_lb'], batch['img_ulb'], batch['img_ulb_wa']
                    ], 0).to(device)
                else:
                    inputs = batch['img_lb'].to(device)
                labels = batch['label'].to(device)
                del batch

                # zero the parameter gradients
                optimizer.zero_grad()

                # Forward the model
                with torch.set_grad_enabled(phase == 'train'):
                    # Calculate labeled loss
                    outputs = model(inputs)
                    if purpose != 'baseline' and phase == 'train':
                        outputs_lb = outputs[:size]
                        outputs_ulb, outputs_ulb_wa = outputs[size:].chunk(2)
                        del outputs
                    else:
                        outputs_lb = outputs
                    _, preds = torch.max(outputs_lb, 1)
                    loss = loss_lb = criterion(outputs_lb, labels)

                    # Calculate unlabeled loss for FixMatch
                    if purpose != 'baseline' and phase == 'train':
                        probs_ulb = torch.softmax(outputs_ulb, dim=-1)
                        probs_ulb, preds_ulb = torch.max(probs_ulb, 1)
                        mask = probs_ulb.ge(threshold).float()
                        if cfg['sharpening']:  # using sharpening
                            if cfg['focal_loss']:
                                sharpen_probs_ulb = torch.softmax(
                                    outputs_ulb / cfg['temperature'], dim=-1)
                                log_pred = F.log_softmax(outputs_ulb_wa,
                                                         dim=-1)
                                loss_ulb = torch.sum(
                                    -sharpen_probs_ulb * log_pred, dim=1)
                                pt = torch.exp(-loss_ulb)
                                loss_ulb = (((1 - pt)**cfg['gamma'] *
                                             loss_ulb)).mean()
                            else:
                                sharpen_label = torch.softmax(
                                    outputs_ulb / cfg['temperature'], dim=-1)
                                log_pred = F.log_softmax(outputs_ulb_wa,
                                                         dim=-1)
                                # Fix: this branch used the undefined name
                                # sharpen_probs_ulb (only assigned in the
                                # focal branch) and raised NameError; the
                                # sharpened target computed here is
                                # sharpen_label.
                                loss_ulb = torch.sum(
                                    -sharpen_label * log_pred, dim=1).mean()
                            # NOTE(review): unlike the pseudo-label branch,
                            # the sharpened loss is not masked by the
                            # confidence threshold — confirm intended.
                            loss += loss_ulb * lambda_u
                        else:  # pseudo label
                            if cfg['focal_loss']:  # Focal loss
                                loss_ulb = F.cross_entropy(outputs_ulb_wa,
                                                           preds_ulb,
                                                           reduction='none')
                                pt = torch.exp(-loss_ulb)
                                loss_ulb = (((1 - pt)**cfg['gamma'] *
                                             loss_ulb) * mask).mean()
                            else:  # Previous loss
                                loss_ulb = (F.cross_entropy(
                                    outputs_ulb_wa, preds_ulb,
                                    reduction='none') * mask).mean()
                            mask_ratio.append(mask.mean().item())
                            loss += loss_ulb * lambda_u

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        if cfg['use_tpu']:
                            xm.optimizer_step(optimizer)
                        else:
                            optimizer.step()

                # Calculate loss and metrics per the batch
                if purpose == 'baseline' or phase == 'test':
                    epoch_loss += loss.item() * size
                else:  # FixMatch
                    epoch_loss += loss_lb.item() * size \
                        + loss_ulb.item() * lambda_u * size * cfg['mu']
                if not cfg['use_tpu'] or cfg['use_tpu'] and phase != 'train':
                    batch_metrics = update_batch_metrics(
                        batch_metrics, preds, labels, class_names)

            if phase == 'train' and scheduler:
                scheduler.step()

            # Calcluate the metrics (e.g. Accuracy) per the epoch
            if not cfg['use_tpu'] or cfg['use_tpu'] and phase != 'train':
                epoch_metrics = get_epoch_metrics(epoch_loss, dataset_sizes,
                                                  phase, class_names,
                                                  batch_metrics, metric_types)
                print_metrics(epoch_metrics,
                              metric_targets,
                              cfg,
                              phase=phase,
                              mask_ratio=mask_ratio)
                # Add prediction results per the epoch
                if phase != 'train':
                    epoch_metrics_list.append(epoch_metrics)

            # Check early stopping
            if phase == 'test' and is_early:
                early_stopping(epoch_metrics['loss']['all'], model)
                if early_stopping.early_stop:
                    print("Early stopping!!")
                    # NOTE(review): this break only exits the phase loop; the
                    # outer epoch loop continues — confirm intended.
                    break

    if not cfg['use_tpu'] or cfg[
            'use_tpu'] and phase != 'train' and xm.is_master_ordinal():
        # Extract best case index
        best_acc = (-1, -1)  # (idx, acc)
        for e_met_idx, e_met in enumerate(epoch_metrics_list):
            if e_met['acc']['all'] > best_acc[1]:
                best_acc = ((e_met_idx, e_met['acc']['all']))
        best_acc_idx = best_acc[0]

        # Set best metrics based on recent 5 epochs metrics
        for metric_type in metric_types:  # e.g. ['acc', 'ppv', ...]
            for metric_target in metric_targets:  # e.g. ['all', 'covid-19', ...]
                # Accuracy couldn't calculate for each class
                if metric_type == 'acc' and metric_target in class_names:
                    continue
                best_metrics[metric_type][metric_target] = \
                    epoch_metrics_list[best_acc_idx][metric_type][metric_target]
        print_metrics(best_metrics, metric_targets, cfg, phase='Best results')

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('-' * 20, '\n')
    return model, best_metrics
def train_mnist(flags,
                training_started=None,
                dynamic_graph=False,
                fetch_often=False):
    """Train MNIST on the current XLA device with profiler tracing enabled.

    Args:
        flags: parsed hyper-parameters (batch_size, lr, datadir, ...).
        training_started: optional threading.Event set after step 15
            (test-synchronization hook).
        dynamic_graph: test-only flag; shrinks the batch each step to force
            recompilation.
        fetch_often: test-only flag; fetches the loss to CPU every step.

    Returns:
        Maximum test accuracy observed across epochs.
    """
    torch.manual_seed(1)
    if flags.fake_data:
        # Synthetic loaders: zero tensors, sized like one replica's shard.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=600000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=100000 // flags.batch_size // xm.xrt_world_size())
    else:
        # Per-ordinal download directory avoids clashes between processes.
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    # Keep a reference so the profiler server stays alive for the whole run.
    server = xp.start_server(flags.profiler_port)

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                index = max(-step, -flags.batch_size + 1)  # non-empty
                data, target = data[:-index, :, :, :], target[:-index]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()
            with xp.StepTrace('train_mnist', step_num=step):
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss_i = loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(
                        _train_update,
                        args=(device, step, loss, tracker, writer))

    def test_loop_fn(loader):
        # Evaluate on this replica's shard; return globally averaged accuracy.
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace('test_mnist'):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]
        accuracy = 100.0 * correct.item() / total_samples
        # Average per-replica accuracies into one global number.
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))
        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={'Accuracy/test': accuracy},
            write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
def train_imagenet():
    """Train an ImageNet model on the current XLA device with AMP.

    Uses GradScaler + autocast, manually all-reduces scaled gradients across
    replicas, and optionally validates each epoch (FLAGS.validate).

    Returns:
        Maximum test accuracy across epochs when FLAGS.validate, else None.
    """
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size(),
        )
        if FLAGS.validate:
            test_loader = xu.SampleGenerator(
                data=(
                    torch.zeros(FLAGS.test_set_batch_size, 3, img_dim,
                                img_dim),
                    torch.zeros(FLAGS.test_set_batch_size,
                                dtype=torch.int64),
                ),
                # NOTE(review): divides by FLAGS.batch_size although the data
                # uses FLAGS.test_set_batch_size — confirm intended.
                sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
            )
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        if FLAGS.validate:
            test_dataset = torchvision.datasets.ImageFolder(
                os.path.join(FLAGS.datadir, "val"),
                # Matches Torchvision's eval transforms except Torchvision uses size
                # 256 resize for all models both here and in the train loader. Their
                # version crashes during training on 299x299 images, e.g. inception.
                transforms.Compose([
                    transforms.Resize(resize_dim),
                    transforms.CenterCrop(img_dim),
                    transforms.ToTensor(),
                    normalize,
                ]),
            )

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            if FLAGS.validate:
                test_sampler = torch.utils.data.distributed.DistributedSampler(
                    test_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        if FLAGS.validate:
            test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=FLAGS.test_set_batch_size,
                sampler=test_sampler,
                drop_last=FLAGS.drop_last,
                shuffle=False,
                num_workers=FLAGS.num_workers,
            )

    device = xm.xla_device()
    model = get_model_property("model_fn")().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
        scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, "lr_scheduler_divide_every_n_epochs", None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer,
    )
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()

    def train_loop_fn(loader, epoch):
        # One training epoch; optionally records fine-grained fwd/bwd/step
        # latencies instead of the rate tracker.
        if FLAGS.fine_grained_metrics:
            epoch_start_time = time.time()
            step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
        else:
            tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if FLAGS.fine_grained_metrics:
                step_start_time = time.time()
            optimizer.zero_grad()
            if FLAGS.fine_grained_metrics:
                fwd_start_time = time.time()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)
            if FLAGS.fine_grained_metrics:
                fwd_end_time = time.time()
                fwd_latency = fwd_end_time - fwd_start_time
                bwd_start_time = time.time()
            # AMP on XLA: backward on the scaled loss, manually all-reduce
            # the gradients across replicas, then step/update the scaler.
            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            if lr_scheduler:
                lr_scheduler.step()
            if FLAGS.fine_grained_metrics:
                bwd_end_time = time.time()
                bwd_latency = bwd_end_time - bwd_start_time
                step_latency = bwd_end_time - step_start_time
                step_latency_tracker.append(step_latency)
                bwd_latency_tracker.append(bwd_latency)
                fwd_latency_tracker.append(fwd_latency)
            else:
                tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                if FLAGS.fine_grained_metrics:
                    print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(\
                        epoch, step, FLAGS.batch_size/p50(step_latency_tracker), FLAGS.batch_size, p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
                else:
                    # _train_update(device, step, loss, tracker, epoch, writer)
                    xm.add_step_closure(
                        _train_update,
                        args=(device, step, loss, tracker, epoch, writer))
        if FLAGS.fine_grained_metrics:
            epoch_end_time = time.time()
            epoch_latency = epoch_end_time - epoch_start_time
            print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(\
                epoch, epoch_latency, FLAGS.batch_size/p50(step_latency_tracker), FLAGS.batch_size, p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))

    def test_loop_fn(loader, epoch):
        # Evaluate on this replica's shard; return globally averaged accuracy.
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                test_utils.print_test_update(device, None, epoch, step)
                # xm.add_step_closure(test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    if FLAGS.validate:
        test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print("Epoch {} train end {}".format(epoch,
                                                       test_utils.now()))
        if FLAGS.validate:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={"Accuracy/test": accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    if FLAGS.validate:
        xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy if FLAGS.validate else None
def main(index):
    """Per-process entry point for TPU language-model fine-tuning.

    Parses the command line, validates flag combinations, binds this
    process to its XLA device, loads config/tokenizer/model (serialized
    behind a file lock on first run so the download is not corrupted),
    optionally freezes layers, and launches training.

    Args:
        index: ordinal of this spawned process; stored as args.local_rank.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--reload_data_file", default=None, type=int,
                        help="Reload dataset every X epoch")
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--eval_data_file", default=None, type=str,
        help=
        "An optional input evaluation data file to evaluate the perplexity on (a text file)."
    )
    parser.add_argument("--model_type", default="bert", type=str,
                        help="The model architecture to be fine-tuned.")
    parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
                        help="The model checkpoint for weights initialization.")
    parser.add_argument(
        "--mlm", action='store_true',
        help=
        "Train with masked-language modeling loss instead of language modeling."
    )
    parser.add_argument(
        "--mlm_probability", type=float, default=0.15,
        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument(
        "--config_name", default="", type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--tokenizer_name", default="", type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )
    parser.add_argument("--tokenizer_class", default="", type=str,
                        help="Optional pretrained tokenizer clas")
    parser.add_argument(
        "--cache_dir", default="", type=str,
        help=
        "Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)"
    )
    parser.add_argument(
        "--block_size", default=-1, type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens)."
    )
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action='store_true',
        help="Run evaluation during training at each logging step.")
    parser.add_argument('--eval_steps', type=int, default=100,
                        help="Evaluate every X updates steps.")
    parser.add_argument(
        "--do_lower_case", action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps', type=int, default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for optimizer.")
    parser.add_argument("--sgd", action='store_true',
                        help="Use SGD instead of Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-6, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=1.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps", default=-1, type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_samples", default=0, type=int,
                        help="Linear warmup over warmup_samples.")
    parser.add_argument("--lr_decay", action='store_true',
                        help="Decay LR using get_linear_schedule_with_warmup.")
    parser.add_argument(
        "--lr_cosine", action='store_true',
        help="LR using get_cosine_with_hard_restarts_schedule_with_warmup.")
    parser.add_argument(
        "--unfreeze_level", default=-1, type=int,
        help="If > 0: freeze all layers except few first and last.")
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        '--save_total_limit', type=int, default=None,
        help=
        'Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default'
    )
    parser.add_argument(
        "--eval_all_checkpoints", action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number"
    )
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache', action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--first_run', action='store_true', help="Cache init")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--fp16', action='store_true',
        help="Whether to use 16-bit/mixed precision instead of 32-bit")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='',
                        help="For distant debugging.")
    args = parser.parse_args()
    # The spawn index becomes this process's rank.
    args.local_rank = index

    # Flag-combination sanity checks before any expensive work.
    if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
        raise ValueError(
            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling).")
    if args.eval_data_file is None and args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument.")
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Bind this process to its XLA device; n_gpu holds the TPU world size.
    args.n_gpu = xm.xrt_world_size()
    args.device = xm.xla_device()

    # Setup logging: only the master ordinal logs at INFO.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if xm.is_master_ordinal() else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
        args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1))

    # Set seed
    # That is actually very important in case of distributed environment (like TPU). You need same dataset on every node/process.
    # If you have randomness in dataset creation (like I do) you need to set the same seed in every process.
    set_seed(args)

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Resume from output_dir if a checkpoint exists; otherwise mark first run.
    if os.path.exists(os.path.join(args.output_dir, WEIGHTS_NAME)):
        args.model_name_or_path = args.output_dir
    else:
        args.first_run = True

    # load model from web in single thread or file will be corrupted.
    lock = FileLock("the.lock") if args.first_run else contextlib.suppress()
    with lock:
        config = config_class.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path)
        if args.tokenizer_class:
            # NOTE(review): resolves the class by name from this module's
            # globals — the named class must be imported here.
            tokenizer_class = globals()[args.tokenizer_class]
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name
            if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case)
        if args.block_size <= 0:
            args.block_size = tokenizer.max_len_single_sentence  # Our input block size will be the max possible for the model
        args.block_size = min(args.block_size,
                              tokenizer.max_len_single_sentence)
        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=bool('.ckpt' in args.model_name_or_path),
            config=config)
        if args.fp16:
            model = model2half(model)

    model = model.to(args.device)
    # see https://github.com/pytorch/xla/issues/1245
    model.tie_weights()

    def req_len(model):
        # Count trainable parameters across the flattened module list.
        return len([
            param for item in flatten_model(model)
            for param in item.parameters() if param.requires_grad
        ])

    # freeze all layers but few first and last
    if args.unfreeze_level >= 0:
        b_req_len = req_len(model)
        flat = flatten_model(model)
        flat = [item for item in flat if list(item.parameters())]
        # NOTE(review): 3 modules per "level" is assumed here — confirm it
        # matches the architecture flatten_model produces.
        i_start = 3
        i_end = 1
        need_grads = set(flat[:i_start + args.unfreeze_level * 3]) | set(
            flat[-(i_end + args.unfreeze_level * 3):])
        for item in flat:
            requires_grad(item, item in need_grads)
        log_info(
            f"Num of layers before {b_req_len}, after freeze {req_len(model)}")

    log_info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train(args, model, tokenizer)
def tpu_training_loop(index):
    """Fine-tune GPT-2 on recipe data, one spawned process per TPU core.

    Loads model/tokenizer under an exclusive file lock (so cores download
    and transfer one model at a time), adds recipe special tokens, trains
    with per-core data sharding, and saves a checkpoint after each epoch
    from the master ordinal.

    Args:
        index: spawn ordinal of this process (unused directly; xm APIs
            report the ordinal).
    """
    torch.set_default_tensor_type('torch.FloatTensor')

    # To decrease exploding RAM usage, only load and transfer one model at a
    # time: hold an exclusive lock until the model has reached the device.
    lock_file = "tpu.lock"
    fd = open(lock_file, "w")
    fcntl.lockf(fd, fcntl.LOCK_EX)

    model_class = GPT2LMHeadModel
    model = model_class.from_pretrained("gpt2")
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained("gpt2", do_lower_case=False)
    device = xm.xla_device()

    # Only the master ordinal logs/records to TensorBoard.
    logger_is_me = False
    if xm.is_master_ordinal():
        logger_is_me = True
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter()

    # Recipe-structure markers; embeddings are resized to match below.
    special_tokens = {
        "additional_special_tokens": [
            "<TITLE_START>", "<TITLE_END>", "<INSTR_START>", "<NEXT_INSTR>",
            "<INSTR_END>", "<INGR_START>", "<NEXT_INGR>", "<INGR_END>",
            "<RECIPE_START>", "<RECIPE_END>", "<INPUT_START>", "<INPUT_END>",
            "<NEXT_INPUT>"
        ]
    }
    tokenizer.add_special_tokens(special_tokens)
    model.resize_token_embeddings(len(tokenizer))

    train_dataset = TextDataset(file_path="train")
    test_dataset = TextDataset(file_path="test")
    # Shard both datasets across TPU cores.
    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=xm.xrt_world_size(),
                                       rank=xm.get_ordinal(),
                                       shuffle=True)
    test_sampler = DistributedSampler(test_dataset,
                                      num_replicas=xm.xrt_world_size(),
                                      rank=xm.get_ordinal(),
                                      shuffle=False)

    # PARAMS!!
    train_batch_size = 4
    test_batch_size = 4

    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=train_batch_size)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=test_sampler,
                                 batch_size=test_batch_size)

    model.train().to(device)
    import gc
    gc.collect()
    # Model is on-device; let the next core start loading.
    fcntl.lockf(fd, fcntl.LOCK_UN)

    gradient_steps = 1
    epochs = 1
    t_total = len(train_dataloader) // gradient_steps

    # NOTE(review): both parameter groups use weight_decay 0.0, so the
    # no_decay split currently has no effect — confirm this is intended.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]

    # one optimizer and scheduler per TPU core. Both objects are saved in
    # `context` to be reused the next epoch.
    # LR is scaled by world size since each core sees a distinct shard.
    lr = 5e-5 * xm.xrt_world_size()
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=t_total)
    tracker = xm.RateTracker()

    # PARAMS V2!!!
    gradient_steps = 1
    logging_steps = 100
    validation_steps = 1000

    optimizer.zero_grad()

    def single_epoch(big_step, epoch):
        """Run one training epoch; returns the updated optimizer-step count."""
        train_sampler.set_epoch(epoch)  # re-shuffle shards for this epoch
        para_loader = pl.ParallelLoader(train_dataloader, [device])
        for step, batch in enumerate(para_loader.per_device_loader(device)):
            # Causal LM: the batch serves as both inputs and labels.
            inputs, labels = (batch, batch)
            model.train()
            outputs = model(inputs, labels=labels)
            loss = outputs[0]
            loss = loss / gradient_steps
            loss.backward()
            tracker.add(1)
            if (step + 1) % gradient_steps == 0:
                xm.optimizer_step(optimizer)
                scheduler.step()
                optimizer.zero_grad()
                big_step += 1
                if logger_is_me and (big_step + 1) % logging_steps == 0:
                    xm.add_step_closure(_train_update,
                                        args=(device, big_step, loss, tracker,
                                              scheduler, writer))
                if (big_step + 1) % validation_steps == 0:
                    perplexity = evaluate(model, test_dataloader, device)
                    if logger_is_me:
                        print("Validation loss: ", perplexity)
                        writer.add_scalar("Validation loss", perplexity,
                                          big_step)
        return big_step

    big_step = 0
    # Always pretend to have one more epoch to do, otherwise model won't get
    # saved.
    for i in range(1, 6):
        print("Epoch: " + str(i))
        big_step = single_epoch(big_step, i)
        if logger_is_me:
            # Checkpoint after every epoch, from the master only.
            output_dir = "gpt2-refined-epoch-" + str(i)
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            save_model(model, output_dir)
            tokenizer.save_pretrained(output_dir)
            print("Model saved")
def log_info(*args, **kwargs): if xm.is_master_ordinal(): logger.info(*args, **kwargs)
def train(args, train_dataset, model, tokenizer):
    """Train the model on a TPU core, SQuAD-style batches.

    Builds a sharded DataLoader, an AdamW optimizer with linear warmup,
    optionally resumes step counters from a checkpoint path, and runs the
    epoch/step loop with gradient accumulation and TensorBoard logging on
    the master ordinal.

    Returns:
        (global_step, average training loss per step).
    """
    is_master = xm.is_master_ordinal()
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size
    # Shard the dataset across TPU cores.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        sampler=train_sampler,
        num_workers=8,
        drop_last=True)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay);
    # biases and LayerNorm weights are exempt from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # NOTE(review): relies on args.warmup_steps, which is not among the
    # flags declared in main() here — confirm the caller sets it.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    # BUG FIX: "Total optimization steps" was stranded as bare (non-string)
    # text and the value was logged as a bare '= %d'; restored to one line.
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to gobal_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)
            logger.info(
                " Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d",
                        global_step)
            logger.info(" Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info(" Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    # Only the master wraps iterators in tqdm; workers get plain iterables.
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in
                            [-1, 0]) if is_master else range(
                                epochs_trained, int(args.num_train_epochs))
    # Added here for reproductibility
    set_seed(args)
    for _ in train_iterator:
        para_loader = pl.ParallelLoader(train_dataloader, [args.device])
        epoch_iterator = tqdm(
            para_loader.per_device_loader(args.device),
            desc="Iteration",
        ) if is_master else para_loader.per_device_loader(args.device)
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            # batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }
            if args.model_type in [
                    "xlm", "roberta", "distilbert", "camembert"
            ]:
                del inputs["token_type_ids"]
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(
                        model.config, "lang2id"):
                    inputs.update({
                        "langs":
                        (torch.ones(batch[0].shape, dtype=torch.int64) *
                         args.lang_id).to(args.device)
                    })

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                xm.optimizer_step(optimizer)
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                # NOTE: periodic checkpoint saving at save_steps was
                # commented out upstream; re-enable deliberately if needed.

            if args.max_steps > 0 and global_step > args.max_steps:
                # BUG FIX: only the master's tqdm wrapper has close();
                # non-master iterators (plain loader/range) do not.
                if is_master:
                    epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            if is_master:
                train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()
    return global_step, tr_loss / global_step
def valid(args, valid_loader, model, device, tokenizer, epoch, f, max_seq_len):
    """Run one evaluation pass over ``valid_loader``.

    Computes start/end-position losses and accuracies plus the Jaccard
    score per batch, mesh-reduces the running averages for the progress
    bar, and appends a summary line to the open log file ``f``.

    Returns:
        (mean jaccard score, mean total loss) for this process.
    """
    total_loss = AverageMeter()
    losses1 = AverageMeter()  # start-position loss
    losses2 = AverageMeter()  # end-position loss
    accuracies1 = AverageMeter()  # start-position accuracy
    accuracies2 = AverageMeter()  # end-position accuracy
    jaccard_scores = AverageMeter()

    model.eval()
    with torch.no_grad():
        # Progress bar only on the master ordinal.
        t = tqdm(valid_loader, disable=not xm.is_master_ordinal())
        for step, d in enumerate(t):
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            token_type_ids = d["token_type_ids"].to(device)
            start_position = d["start_position"].to(device)
            end_position = d["end_position"].to(device)
            sentiment_label = d["sentiment_label"].to(device)

            logits1, logits2 = model(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids,
                                     position_ids=None,
                                     head_mask=None)

            loss1, loss2 = loss_fn((logits1, logits2),
                                   (start_position, end_position))
            loss = loss1 + loss2

            acc1, n_position1 = get_position_accuracy(logits1, start_position)
            acc2, n_position2 = get_position_accuracy(logits2, end_position)

            total_loss.update(loss.item(), n_position1)
            losses1.update(loss1.item(), n_position1)
            losses2.update(loss2.item(), n_position2)
            accuracies1.update(acc1, n_position1)
            accuracies2.update(acc2, n_position2)

            jac_score = calculate_jaccard_score(features_dict=d,
                                                start_logits=logits1,
                                                end_logits=logits2,
                                                tokenizer=tokenizer)
            jaccard_scores.update(jac_score)

            # Reduce running averages across cores for display only.
            print_loss = xm.mesh_reduce("vloss_reduce", total_loss.avg,
                                        reduce_fn)
            print_jac = xm.mesh_reduce("jac_reduce", jaccard_scores.avg,
                                       reduce_fn)
            print_acc1 = xm.mesh_reduce("vacc1_reduce", accuracies1.avg,
                                        reduce_fn)
            print_acc2 = xm.mesh_reduce("vacc2_reduce", accuracies2.avg,
                                        reduce_fn)
            t.set_description(
                f"Eval E:{epoch+1} - Loss:{print_loss:0.2f} - Jac:{print_jac:0.2f} - acc1:{print_acc1:0.2f} - acc2:{print_acc2:0.2f}"
            )

    # BUG FIX: the summary line was assembled with backslash line
    # continuations that injected literal "\v" (vertical tab) characters
    # before "valid_loss2"/"valid_acc2"; build it from adjacent f-strings.
    log_ = (f"Epoch : {epoch+1} - valid_loss : {total_loss.avg} - \n"
            f"valid_loss1 : {losses1.avg} - valid_loss2 : {losses2.avg} - \n"
            f"valid_acc1 : {accuracies1.avg} - valid_acc2 : {accuracies2.avg} ")
    f.write(log_ + "\n\n")
    f.flush()

    return jaccard_scores.avg, total_loss.avg
def map_fn(index, args): """ for tpu """ # Setup tpu device = xm.xla_device() args.device = device is_master = xm.is_master_ordinal() # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO if is_master else logging.DEBUG, ) # Set seed set_seed(args) args.model_type = args.model_type.lower() config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] config = config_class.from_pretrained( args.config_name if args.config_name else args.model_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None, ) tokenizer = tokenizer_class.from_pretrained( args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None, ) model = model_class.from_pretrained( args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, cache_dir=args.cache_dir if args.cache_dir else None, ) model.to(args.device) logger.info("Training/evaluation parameters %s", args) # Training if args.do_train: train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) logger.info(" data load finished! ") global_step, tr_loss = train(args, train_dataset, model, tokenizer) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) # Save the trained model and the tokenizer if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): # Create output directory if needed if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: os.makedirs(args.output_dir) logger.info("Saving model checkpoint to %s", args.output_dir) # Save a trained model, configuration and tokenizer using `save_pretrained()`. 
# They can then be reloaded using `from_pretrained()` # Take care of distributed/parallel training model_to_save = model.module if hasattr(model, "module") else model model_to_save.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model torch.save(args, os.path.join(args.output_dir, "training_args.bin")) # Load a trained model and vocabulary that you have fine-tuned model = model_class.from_pretrained( args.output_dir) # , force_download=True) tokenizer = tokenizer_class.from_pretrained( args.output_dir, do_lower_case=args.do_lower_case) model.to(args.device)
def train_loop_fn(model, loader, device, context): loss_fn = nn.CrossEntropyLoss() optimizer = context.getattr_or( "optimizer", lambda: optim.SGD( model.parameters(), lr=FLAGS.lr, momentum=FLAGS.momentum, weight_decay=1e-4, ), ) lr_scheduler = context.getattr_or( "lr_scheduler", lambda: schedulers.wrap_optimizer_with_scheduler( optimizer, scheduler_type=getattr(FLAGS, "lr_scheduler_type", None), scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None), scheduler_divide_every_n_epochs=getattr( FLAGS, "lr_scheduler_divide_every_n_epochs", None ), num_steps_per_epoch=num_training_steps_per_epoch, summary_writer=writer if xm.is_master_ordinal() else None, ), ) tracker = xm.RateTracker() model.train() total_samples = 0 correct = 0 top5_accuracys = 0 losses = 0 for x, (data, target) in loader: optimizer.zero_grad() output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum() total_samples += data.size()[0] top5_accuracys += topk_accuracy(output, target, topk=5) loss = loss_fn(output, target) loss.backward() losses += loss.item() xm.optimizer_step(optimizer) tracker.add(FLAGS.batch_size) if x % FLAGS.log_steps == 0: print( "[{}]({}) Loss={:.5f} Top-1 ACC = {:.2f} Rate={:.2f} GlobalRate={:.2f} Time={}".format( str(device), x, loss.item(), (100.0 * correct / total_samples).item(), tracker.rate(), tracker.global_rate(), time.asctime(), ) ) if lr_scheduler: lr_scheduler.step() return ( losses / (x + 1), (100.0 * correct / total_samples).item(), (top5_accuracys / (x + 1)).item(), )
def train_imagenet():
    """Train an FSDP-wrapped model on ImageNet (or fake data) on TPU.

    Builds train/test loaders (synthetic when FLAGS.fake_data), wraps the
    model with XLA FSDP (optionally nested per child module, optionally
    with gradient checkpointing), then runs the epoch loop with periodic
    evaluation.

    Returns:
        The maximum test accuracy observed across epochs.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        # NOTE(review): sample_count divides by FLAGS.batch_size although
        # the generated batches use test_set_batch_size — confirm intended.
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            persistent_workers=True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            persistent_workers=True,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')()
    # Wrap the model with FSDP
    # You may wrap all, a subset, or none of the sub-modules with inner FSDPs
    # - to implement ZeRO-2, wrap none of the sub-modules
    # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP)
    # - you may wrap sub-modules at different granularity (e.g. at each resnet
    #   stage or each residual block or each conv layer).
    fsdp_wrap = lambda m: FSDP(m.to(device),
                               compute_dtype=getattr(torch, FLAGS.compute_dtype
                                                     ),
                               fp32_reduce_scatter=FLAGS.fp32_reduce_scatter,
                               flatten_parameters=FLAGS.flatten_parameters)
    # Apply gradient checkpointing to sub-modules if specified
    grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else (
        lambda x: x)
    if FLAGS.use_nested_fsdp:
        # Here we apply inner FSDP at the level of child modules for ZeRO-3, which
        # corresponds to different stages in resnet (i.e. Stage 1 to 5).
        for submodule_name, submodule in model.named_children():
            if sum(p.numel() for p in submodule.parameters()) == 0:
                # Skip those submodules without parameters (i.e. no need to shard them)
                continue
            # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
            m_fsdp = fsdp_wrap(grad_ckpt_wrap(getattr(model, submodule_name)))
            setattr(model, submodule_name, m_fsdp)
    # Always wrap the base model with an outer FSDP
    model = fsdp_wrap(model)

    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.WarmupAndExponentialDecayScheduler(
        optimizer,
        num_steps_per_epoch=num_training_steps_per_epoch,
        divide_every_n_epochs=FLAGS.lr_scheduler_divide_every_n_epochs,
        divisor=FLAGS.lr_scheduler_divisor,
        num_warmup_epochs=FLAGS.num_warmup_epochs,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        # One training epoch over the device loader.
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()  # do not reduce gradients on sharded params
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, epoch,
                                          writer))

    def test_loop_fn(loader, epoch):
        # One evaluation pass; returns mesh-reduced mean accuracy (%).
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(test_utils.print_test_update,
                                    args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))
        # Evaluate on the configured interval, and always on the last epoch.
        run_eval = ((not FLAGS.test_only_at_end and
                     epoch % FLAGS.eval_interval == 0) or
                    epoch == FLAGS.num_epochs)
        if run_eval:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={'Accuracy/test': accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
def train(rank, args):
    """Per-process training entry point (GPU via apex/DDP, or TPU via torch_xla).

    Builds tokenizer, datasets and loaders, instantiates the model and
    optimizer for the selected backend, optionally restores a checkpoint,
    runs the update loop for ``args.total_num_updates`` steps (with optional
    per-pass evaluation), and always saves ``checkpoint_final.pt``.

    Args:
        rank: process index (GPU rank; on TPU it is replaced by the XLA ordinal).
        args: parsed argument namespace; mutated in place (``rank``,
            ``vocab_size``, ``total_num_updates``, ``warmup_updates``).
    """
    print('enter train @ %s' % (rank), flush=True)
    args.rank = rank
    torch.manual_seed(42)

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size()
    train_dataset = get_dataset(args)

    # Small values are treated as epoch counts / fractions and converted to
    # absolute update counts.
    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates
    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)

    train_sampler = None
    if args.gpus:
        # Initialize the NCCL process group exactly once per process.
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=args.gpus, rank=rank, shuffle=False)
    else:
        rank = xm.get_ordinal()  # on TPU the real rank comes from XLA
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=xm.xrt_world_size(), rank=rank,
                shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        # Datasets exposing __getbatch__ already yield full batches.
        batch_size=args.batch_size if not hasattr(train_dataset, '__getbatch__') else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)

    eval_loader = None
    if args.eval_dir:
        # BUG FIX: the original re-initialized the NCCL process group here
        # (which raises), assigned the sampler to a typo'd name
        # (`traieval_samplern_sampler`, leaving eval_sampler = None), and
        # sharded the *train* dataset instead of the eval dataset.
        eval_dataset = get_eval_dataset(args)
        eval_sampler = None
        if args.gpus:
            if args.gpus > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset, num_replicas=args.gpus, rank=rank, shuffle=False)
        elif xm.xrt_world_size() > 1:
            eval_sampler = torch.utils.data.distributed.DistributedSampler(
                eval_dataset, num_replicas=xm.xrt_world_size(), rank=rank,
                shuffle=False)
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size if not hasattr(eval_dataset, '__getbatch__') else None,
            sampler=eval_sampler,
            pin_memory=True,
            shuffle=False,
            num_workers=args.num_workers)

    if args.gpus:
        ##########################
        ##
        ## Model Creation (GPU) + apex mixed precision + DDP
        ##
        ##########################
        assert apex_enabled
        torch.cuda.set_device(rank)
        model = get_model(args)
        model.cuda(rank)
        device = torch.device('cuda:' + str(rank))
        # model_get_parameters applies per-group lr / layer-wise decay settings.
        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay)
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
    else:
        ##########################
        ##
        ## Model Creation (TPU)
        ##
        ## For shared parameters, TPU requires modules to be tied again
        ## after .to(device), so catalog them first and re-share afterwards.
        ##
        ##########################
        assert tpu_enabled
        device = xm.xla_device()
        model = get_model(args)
        shared_parameters = {
            e[0]: e[1:] for e in _catalog_shared_params(model)
        }
        model.to(device)
        do_share_parameters_again(model, shared_parameters, log=rank == 0)
        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),
            lr=args.lr,
            weight_decay=args.weight_decay)
        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)
        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        # Strip DataParallel's 'module.' prefix and rename legacy
        # '*position_ids' buffers to '*position_embeddings'.
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        try:
            model.load_state_dict(states)
        except Exception:
            import traceback
            traceback.print_exc()
            # Retry non-strictly so partially matching checkpoints still load.
            model.load_state_dict(states, strict=False)

    model.train()
    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ## Init LR Scheduler
    ##
    ##########################
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )

    step_i = 0
    err = None
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates)
        while step_i < args.total_num_updates:
            if args.gpus:
                batches = train_loader
            else:
                batches = pl.ParallelLoader(train_loader,
                                            [device]).per_device_loader(device)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break
                report_step = step_i % args.log_interval == 0
                while True:  # replay loop, used only for apex gradient overflow
                    optimizer.zero_grad()
                    total_loss, log = get_loss(model,
                                               sample,
                                               args=args,
                                               device=device,
                                               gpu=args.gpus,
                                               report=report_step)
                    if args.gpus:
                        default_optimizer_step = optimizer.step
                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        # If amp detects an overflow, it patches optimizer.step.
                        # An unpatched step means no overflow -> no replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break
                        # Patched call: does nothing but restore optimizer.step
                        # to default_optimizer_step; then replay the batch.
                        optimizer.step()
                        if rank == 0:
                            print(
                                "Overflowed, reducing loss scale and replaying batch.",
                                flush=True)
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()
                        break
                scheduler.step()
                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter))
                    else:
                        xm.add_step_closure(_train_update,
                                            args=(log, log_formatter))
                        if args.report_metrics:
                            xm.master_print(met.metrics_report())
                if rank == 0:
                    pbar.update(1)
            if eval_loader is not None:
                model.eval()
                # BUG FIX: the original left `batches` pointing at the train
                # loader on the GPU path, so evaluation ran on training data.
                if args.gpus:
                    eval_batches = eval_loader
                else:
                    eval_batches = pl.ParallelLoader(
                        eval_loader, [device]).per_device_loader(device)
                with torch.no_grad():
                    record = OrderedDict()
                    for sample in eval_batches:
                        evaluate(model,
                                 sample,
                                 args=args,
                                 device=device,
                                 record=record,
                                 gpu=args.gpus,
                                 report=report_step)
                    post_evaluate(record, args=args)
                import json
                print('', flush=True)
                print(json.dumps(record), flush=True)
                print('', flush=True)
                # BUG FIX: restore training mode before the next pass.
                model.train()
    except Exception as _err:
        err = _err
    finally:
        # Always persist a final checkpoint, then re-raise any training error.
        save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0 and args.gpus:
            torch.save(model.state_dict(), save_fn)
            if err:
                raise err
        else:
            xm.save(model.state_dict(), save_fn)
            if err:
                raise err
def train(i, num_gpus, rank, group_name, output_directory, epochs,
          learning_rate, sigma, iters_per_checkpoint, batch_size, seed,
          fp16_run, checkpoint_path, with_tensorboard):
    """WaveGlow training loop ported to TPU (torch_xla).

    Trains a WaveGlow model on Mel2Samp data, optionally resuming from
    ``checkpoint_path``, logging the loss from the master ordinal, and saving
    a checkpoint every ``iters_per_checkpoint`` iterations.

    NOTE(review): the parameter ``i`` is shadowed by the inner batch-loop
    variable, and ``xla_train_loader`` is built but the loop iterates the
    plain ``train_loader`` instead — confirm which is intended.
    """
    torch.manual_seed(seed)
    #torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END: ADDED FOR DISTRIBUTED======

    device = xm.xla_device()
    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config)  #.cuda()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END: ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1
    model = model.to(device)

    trainset = Mel2Samp(**data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    # Sampler always shards across XLA ordinals (the original GPU-only guard
    # is kept commented out).
    train_sampler = DistributedSampler(trainset, rank=xm.get_ordinal(),
                                       num_replicas=xm.xrt_world_size())  #if num_gpus > 1 else None
    # =====END: ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=True,
                              drop_last=True)
    xla_train_loader = pl.ParallelLoader(train_loader,
                                         [device]).per_device_loader(device)

    # Get shared output_directory ready
    if xm.is_master_ordinal():
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and xm.is_master_ordinal():
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, 'logs'))

    model.train()
    # Resume at the epoch implied by the restored iteration counter.
    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()

            mel, audio = batch
            # Move inputs to the XLA device (replaces the old Variable(.cuda())).
            mel, audio = mel.to(device), audio.to(device)  #torch.autograd.Variable(mel.cuda())
            mel.requires_grad, audio.requires_grad = True, True  #torch.autograd.Variable(audio.cuda())
            outputs = model((mel, audio))

            loss = criterion(outputs)
            # reduced_loss is only used for the tensorboard scalar below; the
            # console print uses loss.item() directly.
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if xm.is_master_ordinal():
                print("{}:\t{:.9f}".format(iteration, loss.item()))
            # barrier=True forces an XLA step per optimizer update.
            xm.optimizer_step(optimizer, barrier=True)
            #print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if with_tensorboard and xm.is_master_ordinal():
                logger.add_scalar('training_loss', reduced_loss,
                                  i + len(train_loader) * epoch)

            if (iteration % iters_per_checkpoint == 0):
                if xm.is_master_ordinal():
                    checkpoint_path = "{}/waveglow_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
def train_imagenet(state_dict):
    """Train an ImageNet classifier on TPU starting from ``state_dict``.

    Builds real (ImageFolder) or fake (SampleGenerator) data loaders, trains
    for ``FLAGS.num_epochs`` epochs with SGD plus the configured LR scheduler,
    periodically evaluates, and returns the best test accuracy observed.

    Args:
        state_dict: model weights to load before moving the model to the
            XLA device.

    Returns:
        float: the maximum (per-core) test accuracy across all epochs.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        # BUG FIX: the original divided by FLAGS.batch_size here even though
        # this generator yields FLAGS.test_set_batch_size-sized batches.
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.test_set_batch_size //
            xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        # Shard both datasets across XLA ordinals when running multi-core.
        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    device = xm.xla_device()
    model = get_model_property('model_fn')()
    model.load_state_dict(state_dict)
    model = model.to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        # One full pass over the (device) loader; logs every FLAGS.log_steps.
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)  # all-reduce grads, then step
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch,
                                         writer))

    def test_loop_fn(loader, epoch):
        # Computes top-1 accuracy over the test loader.
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    test_utils.print_test_update, args=(device, None, epoch,
                                                        step))
        accuracy = 100.0 * correct.item() / total_samples
        # NOTE(review): this reports the accuracy of this core's shard only;
        # the cross-core reduction below was commented out — confirm whether
        # that was intentional before relying on the reported number.
        # accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch,
                                                         test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))
        if not FLAGS.test_only_at_end or epoch == FLAGS.num_epochs:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={'Accuracy/test': accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())
    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
def train(args, model, tokenizer):
    """ Train the model (language-model fine-tuning on TPU via torch_xla).

    Builds the optimizer with weight-decay grouping, picks an LR schedule,
    then runs epochs of gradient-accumulated training with TensorBoard
    logging from the master ordinal and periodic checkpointing.

    Returns:
        (global_step, moving average of the training loss).
    """
    if xm.is_master_ordinal():
        tb_writer = SummaryWriterP(args.output_dir)

    def summary_write(*args, **kwargs):
        # Only the master ordinal owns a writer; everyone else no-ops.
        if xm.is_master_ordinal():
            tb_writer.add_scalar(*args, **kwargs)

    args.train_batch_size = args.per_gpu_train_batch_size  #* max(1, args.n_gpu)
    train_dataloader = build_dataloader(args, tokenizer)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]

    # Scale learning rate to num cores
    #args.learning_rate = args.learning_rate * xm.xrt_world_size()
    if args.sgd:
        optimizer = SGD(optimizer_grouped_parameters, lr=args.learning_rate)
    else:
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    # Warmup length is expressed in samples; convert to optimizer steps.
    warmup_steps = args.warmup_samples // (args.train_batch_size *
                                           xm.xrt_world_size())
    if args.lr_decay:
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    warmup_steps=warmup_steps,
                                                    t_total=t_total)
    elif args.lr_cosine:
        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            warmup_steps=warmup_steps,
            t_total=t_total,
            cycles=args.num_train_epochs)
    else:
        scheduler = WarmupZeroSchedule(optimizer, warmup_steps=warmup_steps)

    # Train!
    tracker = xm.RateTracker()
    log_info("***** Running training *****")
    log_info("  Num Epochs = %d", args.num_train_epochs)
    log_info("  Instantaneous batch size per GPU = %d",
             args.per_gpu_train_batch_size)
    log_info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (xm.xrt_world_size() if args.local_rank != -1 else 1))
    log_info("  Gradient Accumulation steps = %d",
             args.gradient_accumulation_steps)
    log_info("  Total optimization steps = %d", t_total)

    # Resume the step counter from a previous run when step.txt exists.
    try:
        with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c:
            global_step = int(c.readline())
    except OSError as e:
        global_step = 0

    moving_loss = MovingLoss(10000 // args.logging_steps)

    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=not xm.is_master_ordinal())
    try:
        for epoch in train_iterator:
            p_train_dataloader = pl.ParallelLoader(train_dataloader,
                                                   [args.device])
            epoch_iterator = tqdm(p_train_dataloader.per_device_loader(
                args.device),
                                  total=len(train_dataloader),
                                  desc="Iteration",
                                  disable=not xm.is_master_ordinal())

            model.train()
            for step, batch in enumerate(epoch_iterator):
                optimizer.zero_grad()
                inputs, labels = mask_tokens(
                    batch, tokenizer, args) if args.mlm else (batch, batch)
                outputs = model(
                    inputs, masked_lm_labels=labels) if args.mlm else model(
                        inputs, labels=labels)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                    xm.optimizer_step(optimizer, barrier=True)
                    scheduler.step()
                    global_step += 1
                    tracker.add(args.train_batch_size)

                    if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        ls = loss.item(
                        )  # weird. if you call loss.item() only in one process, the whole thing hangs. So call on every and log in one.
                        moving_loss.add(ls)
                        summary_write('lr',
                                      scheduler.get_last_lr()[0], global_step)
                        epoch_iterator.set_postfix(
                            MovingLoss=f'{moving_loss.loss:.2f}',
                            Perplexity=
                            f'{torch.exp(torch.tensor(moving_loss.loss)):.2f}')

                    if args.save_steps > 0 and global_step % args.save_steps == 0:
                        save_state(args, model, tokenizer, global_step)

                #if step >= 1023: # TPU seems to like consistent epoch lenght
                #    epoch_iterator.close()
                #    break
                if args.max_steps > 0 and step > args.max_steps:
                    epoch_iterator.close()
                    break

            # evaluate once in an epoch
            if args.evaluate_during_training:
                results = evaluate(args, model, tokenizer,
                                   f"checkpoint-{global_step}")
                log_info(f"Eval {results}")
                for key, value in results.items():
                    summary_write("eval_{}".format(key), value, global_step)

            # reload dataset every args.reload_data_file epochs
            if args.reload_data_file and (epoch +
                                          1) % args.reload_data_file == 0:
                train_dataloader = build_dataloader(args, tokenizer)

            # that's very slow on TPU
            #print_sample(model, tokenizer, args.device, args)

    except (KeyboardInterrupt, SystemExit):
        # Persist progress on manual interruption, then re-raise.
        save_state(args, model, tokenizer, global_step)
        raise

    save_state(args, model, tokenizer, global_step)
    return global_step, moving_loss.loss
def is_local_master(self) -> bool:
    """Return True when this process is the master on its local host."""
    if is_tpu_available():
        # On TPU, defer to torch_xla's local-ordinal check.
        return xm.is_master_ordinal(local=True)
    # Otherwise: rank -1 (non-distributed) or local rank 0 counts as master.
    return self.args.local_rank in (-1, 0)
def summary_write(*args, **kwargs):
    """Forward a scalar to the TensorBoard writer, master ordinal only."""
    if not xm.is_master_ordinal():
        return
    tb_writer.add_scalar(*args, **kwargs)
def train(train_loader, model, optimizer, scheduler, epoch, args, DEVICE):
    """Run one training epoch on a TPU core, tracking loss and top-1/top-5 accuracy.

    The loader is wrapped in a torch_xla ParallelLoader so batches are
    prefetched onto ``DEVICE``; progress is displayed every
    ``args.print_freq`` steps from the master ordinal only.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model = model.train().to(DEVICE)

    loader = pl.ParallelLoader(train_loader, [DEVICE]).per_device_loader(DEVICE)

    # noise2net = Res2Net(epsilon=0.50, hidden_planes=16, batch_size=args.batch_size).train().to(DEVICE)

    end = time.time()
    for i, (images, target) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        bx = images
        by = target

        # NOTE(review): the bare print() calls below look like leftover
        # debug output — consider removing or gating them behind a flag.
        print("Zero grad")
        optimizer.zero_grad()

        # with torch.no_grad():
        #     if random.random() < 0.5:
        #         batch_size = bx.shape[0]
        #         noise2net.reload_parameters()
        #         noise2net.set_epsilon(random.uniform(args.noisenet_max_eps / 2.0, args.noisenet_max_eps))
        #         bx = bx.reshape((1, batch_size * 3, 224, 224))
        #         bx = noise2net(bx)
        #         bx = bx.reshape((batch_size, 3, 224, 224))

        print("Forward")
        logits = model(bx)

        print("Cross Entropy")
        loss = F.cross_entropy(logits, by)

        # measure accuracy and record loss
        output, target = logits, by
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        print("Backward")
        loss.backward()

        print("Step")
        xm.optimizer_step(optimizer)

        print("Scheduler step")
        scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and xm.is_master_ordinal():
            progress.display(i)
def run():
    """
    Main function to setup the training loop and evaluation loop.
    See comments for detailed explanation.

    Returns:
        None, but it saves the model weights and model performance,
        based on the get_map_fn arguments
    """
    # xla will assign a device for each forked run of this function
    device = xm.xla_device()
    # determine if this fork is the master fork to avoid logging and print 8 times
    master = xm.is_master_ordinal()
    if master:
        logger.info("running at batch size %i" % batch_size)
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
    model = WRAPPED_MODEL.to(device)

    # standard data prep
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
    train_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ]
    )
    if args.cutout > 0:
        train_transform.transforms.append(Cutout(args.cutout))
    train_data = CifarDataset(transform=train_transform)

    # distributed samples ensure data is sharded to each tpu core
    # if you do not use this, you are only using 1 of the 8 cores
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size // xm.xrt_world_size(),
        sampler=train_sampler,
        drop_last=True,
        num_workers=0,
    )
    valid_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ]
    )
    valid_data = my_cifar10.CIFAR10(
        root=data_root, train=False, download=False, transform=valid_transform
    )
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        sampler=valid_sampler,
        batch_size=batch_size // xm.xrt_world_size(),
        drop_last=True,
        num_workers=0,
    )

    # standard optimizer stuff
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    if args.opt == "sgd":
        optimizer = torch.optim.SGD(
            parameters,
            args.learning_rate,
            momentum=momentum,
            weight_decay=args.weight_decay,
        )
    elif args.opt == "lamb":
        optimizer = Lamb(
            parameters, lr=args.learning_rate, weight_decay=weight_decay
        )
    else:
        raise NameError("Unknown Optimizer %s" % args.opt)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs))

    # training by epoch loop
    for epoch in range(epochs):
        # the model needs a droprate, so just assign it
        model.droprate = drop_path_prob * epoch / epochs
        start = datetime.datetime.now()
        st = start.strftime("%Y-%m-%d %H:%M:%S")
        if master:
            logger.info("starting epoch %i at %s" % (epoch, st))
        # parallel loader necessary to load data in parallel to each core
        para_loader = pl.ParallelLoader(train_queue, [device]).per_device_loader(
            device
        )
        correct, train_loss, total = train(
            para_loader, model, criterion, optimizer, params, device
        )
        train_acc = 100 * correct / total
        # collect the train accuracies from all cores
        train_acc = xm.mesh_reduce("avg acc", train_acc, np.mean)
        end = datetime.datetime.now()
        duration = (end - start).total_seconds()
        if master:
            logger.info("train_acc %f duration %f" % (train_acc, duration))
        scheduler.step()
        # validate using 8 cores and collect results
        valid_acc, valid_obj = infer(valid_queue, model, criterion, device)
        valid_acc = xm.mesh_reduce("val avg acc", valid_acc, np.mean)
        if master:
            logger.info("valid_acc %f" % valid_acc)

    # count flops
    # NOTE(review): indentation reconstructed — the flops count and final save
    # are taken to run once, after the epoch loop; confirm against the
    # original layout.
    _ = add_flops_counting_methods(model)
    model.eval()
    model.start_flops_count()
    random_data = torch.randn(1, 3, 32, 32)
    model(torch.autograd.Variable(random_data).to(device))
    n_flops = np.round(model.compute_average_flops_cost() / 1e6, 4)
    n_flops = xm.mesh_reduce("flops", n_flops, np.mean)
    if master:
        logger.info("flops %f" % n_flops)

    if master:
        logger.info("saving")
    # save weights and results
    xm.save([valid_acc, n_flops], "results.pt")