def __init__(self, path, net, run_config, init=True, measure_latency=None, no_gpu=False):
    self.path = path
    self.net = net
    self.run_config = run_config

    self.best_acc = 0.0
    self.start_epoch = 0

    os.makedirs(self.path, exist_ok=True)

    self.device = xm.xla_device()
    self.net = xmp.MpModelWrapper(self.net).to(self.device)

    # initialize model (default)
    if init:
        init_models(self.net, run_config.model_init)

    # net info
    net_info = get_net_info(self.net, self.run_config.data_provider.data_shape, measure_latency, True)
    with open('%s/net_info.txt' % self.path, 'w') as fout:
        fout.write(json.dumps(net_info, indent=4) + '\n')
        # noinspection PyBroadException
        try:
            fout.write(self.network.module_str + '\n')
        except Exception:
            pass
        fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
        fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
        fout.write('%s\n' % self.network)

    # criterion
    if isinstance(self.run_config.mixup_alpha, float):
        self.train_criterion = cross_entropy_loss_with_soft_target
    elif self.run_config.label_smoothing > 0:
        self.train_criterion = lambda pred, target: \
            cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
    else:
        self.train_criterion = nn.CrossEntropyLoss()
    self.test_criterion = nn.CrossEntropyLoss()

    # optimizer
    if self.run_config.no_decay_keys:
        keys = self.run_config.no_decay_keys.split('#')
        net_params = [
            self.net.get_parameters(keys, mode='exclude'),  # parameters with weight decay
            self.net.get_parameters(keys, mode='include'),  # parameters without weight decay
        ]
    else:
        # noinspection PyBroadException
        try:
            net_params = self.network.weight_parameters()
        except Exception:
            net_params = [param for param in self.network.parameters() if param.requires_grad]
    self.optimizer = self.run_config.build_optimizer(net_params)
def get_net(name, pretrained=False):
    if name not in nets:
        net = GeneralizedCassavaClassifier(name, pretrained=pretrained)
    else:
        net = nets[name](pretrained=pretrained)
    if config.USE_TPU:
        net = xmp.MpModelWrapper(net)
    return net
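
# A minimal sketch (assumptions: the _mp_fn name, the "resnet50" argument and
# nprocs=8 are illustrative, not from the original source) of consuming the
# net returned by get_net on TPU: xmp.MpModelWrapper keeps one copy of the
# weights in host memory, and each spawned process materializes them on its
# own XLA device with .to(device).
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

WRAPPED_NET = get_net("resnet50")  # built once at module scope

def _mp_fn(index):
    device = xm.xla_device()
    model = WRAPPED_NET.to(device)  # MpModelWrapper.to() returns the module on `device`
    # ... build loaders and train `model` here ...

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8, start_method="fork")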
def __init__(self, model, loss_fn_class, optimizer_class, metrics):
    self.using_tpu = "COLAB_TPU_ADDR" in os.environ or "TPU_IP_ADDRESS" in os.environ
    if self.using_tpu:
        self.model = xmp.MpModelWrapper(model)
    else:
        self.model = model
    self.Loss_fn = loss_fn_class
    self.Opt = optimizer_class
    self.metrics = metrics
    self.train_eval = {"train_loss": AverageMeter("train_loss")}
    for key in metrics:
        self.train_eval[f"train_{key}"] = AverageMeter(f"train_{key}")
    self.dev_eval = {"dev_loss": AverageMeter("dev_loss")}
    for key in metrics:
        self.dev_eval[f"dev_{key}"] = AverageMeter(f"dev_{key}")
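
# A minimal sketch (assumption: setup_worker is a hypothetical helper, not
# part of the original Trainer) of materializing self.model per worker. Both
# xmp.MpModelWrapper and a plain nn.Module return the module from .to(device),
# so the TPU and non-TPU paths can share a single call.
def setup_worker(trainer):
    device = xm.xla_device() if trainer.using_tpu else \
        torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = trainer.model.to(device)  # wrapper or plain module, same call
    loss_fn = trainer.Loss_fn()
    optimizer = trainer.Opt(model.parameters())
    return model, loss_fn, optimizer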
def pack_learner_args(self: Learner):
    "Pack learner args into a dict to pass to the spawned process."
    learner_args = {**self.__stored_args__}
    learner_args['wrapped_model'] = xmp.MpModelWrapper(self.model)
    learner_args['base_dls'] = self.dls
    # fetch only cbs not in defaults
    if ProgressCallback not in defaults.callbacks:
        defaults.callbacks.append(ProgressCallback)
    default_cbs = [cls() for cls in defaults.callbacks]
    learner_args['cbs'] = [
        cb for cb in self.cbs
        if cb.name not in L(default_cbs).attrgot('name')
    ]
    learner_args['master_cbs'] = self.master_cbs
    # remove extra args from learner_args (in __stored_args__ but not in init args)
    add_args = {}
    for arg in _extra_args:
        if arg in learner_args:
            add_args[arg] = learner_args.pop(arg)
    return learner_args, add_args
def connect(self, model: "pl.LightningModule") -> None:
    TPUSpawnStrategy._validate_patched_dataloaders(model)
    self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
    return super().connect(model)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.bn1(x)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(MNIST())


def train_mnist():
    torch.manual_seed(1)
    """
    When using a TPU, this is what has to happen on the dataset side:
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    """
    def get_dataset():
import cv2
import pandas as pd
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data import Dataset

from config import args
from models import ShopeeModel

SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ShopeeModel())


class ShopeeDataset(Dataset):
    def __init__(self, data, root_dir=args.train_dir, transform=args.train_args, train=True):
        self.data = data
        self.root_dir = root_dir
        self.transform = transform
        self.train = train

    def __len__(self):
        return len(self.data)
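
# A minimal sketch (assumptions: _mp_fn, the batch size and args.train_csv are
# illustrative names, not from the original file) of how SERIAL_EXEC and
# WRAPPED_MODEL above are consumed per process: MpSerialExecutor runs the
# dataset build one process at a time to cap host memory, and each replica
# then shards the data with a DistributedSampler.
def _mp_fn(index):
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    train_dataset = SERIAL_EXEC.run(
        lambda: ShopeeDataset(pd.read_csv(args.train_csv)))  # hypothetical csv field
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=32, sampler=sampler, num_workers=4)
    # training then iterates pl.ParallelLoader(loader, [device]).per_device_loader(device)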
    )
    # val_loader = 1
    xm.master_print(
        f">>> Training examples on 1 core: {len(train_loader) * global_config.TPU_BATCH_SIZE}")
    torch.set_default_tensor_type('torch.FloatTensor')
    device = xm.xla_device()
    net.model = net.model.to(device)
    net.fit(train_loader, val_loader)


net = EfficientNet_Model(device="TPU", config=global_config, steps=100)

# Continue training from a checkpoint
if global_config.CONTINUE_TRAIN:
    net.load(global_config.CONTINUE_TRAIN)
    xm.master_print(
        f">>> {global_config.CONTINUE_TRAIN} is LOADED for resuming training!")

net.model = xmp.MpModelWrapper(net.model)  # wrap the model for seamless distributed training
xm.master_print(">>> Ready to fit Train Set...")

# xmp.spawn() should be guarded by `if __name__ == '__main__':`
if __name__ == '__main__':
    # torch.multiprocessing.freeze_support()
    FLAGS = {}
    xmp.spawn(_mp_fn, args=(FLAGS, ), nprocs=8, start_method='fork')
import os

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from src.config import *
from src.utils import *
from src.loss import FocalCosineLoss, SmoothCrossEntropyLoss, bi_tempered_logistic_loss

if USE_TPU:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.distributed.parallel_loader as pl
    os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

df = pd.read_csv(TRAIN_FOLDS)
dataloader = get_infer_dataloader(infer=df)

device = get_device(n=0)
net = get_net(name=NET, pretrained=False)
net.load_state_dict(torch.load(
    "../input/model-weights/SEResNeXt50_32x4d_BH_fold_2_11.bin",
    map_location=torch.device('cpu')))
net = xmp.MpModelWrapper(net) if USE_TPU else net
net = net.to(device)
net.eval()  # inference mode: disable dropout / use running BN stats

preds = np.empty((0, 5), dtype=np.float64)
with torch.no_grad():
    for images, labels in tqdm(dataloader):
        images, labels = images.to(device), labels.to(device)
        predictions = net(images).detach().cpu().numpy()
        preds = np.concatenate([preds, predictions], axis=0)
print(preds.shape)

ids = df["image_id"].to_numpy().reshape(-1, 1)
preds = np.concatenate([ids, preds], axis=1)
print(preds.shape)

preds = pd.DataFrame(preds, columns=['id', '0', '1', '2', '3', '4'])
preds.to_csv("SEResNeXt50_32x4d_BH_2_preds.csv", index=False)
def main(rank):
    # Seed - Added for TPU purposes
    torch.manual_seed(1)

    # Create log folders
    root = 'result_fg/'
    model = 'coco_model_'
    result_folder_name = 'images_' + FLAGS['log_dir']
    model_folder_name = 'models_' + FLAGS['log_dir']
    if not os.path.isdir(root):
        os.mkdir(root)
    if not os.path.isdir(root + result_folder_name):
        os.mkdir(root + result_folder_name)
    if not os.path.isdir(root + model_folder_name):
        os.mkdir(root + model_folder_name)

    # Save a copy of this script
    copyfile(os.path.basename(__file__),
             root + result_folder_name + '/' + os.path.basename(__file__))

    # Define transformation for dataset images - e.g. scaling
    transform = transforms.Compose([
        transforms.Resize((FLAGS['img_size'], FLAGS['img_size'])),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load dataset
    category_names = FLAGS['category_names'].split(',')

    # Serial executor - needed so the dataset is built one TPU process at a time (memory)
    SERIAL_EXEC = xmp.MpSerialExecutor()

    # Define dataset
    dataset = SERIAL_EXEC.run(
        lambda: CocoData(
            root=FLAGS['train_imgs_path'],
            annFile=FLAGS['train_annotation_path'],
            category_names=category_names,
            transform=transform,
            final_img_size=FLAGS['img_size'],
        )
    )
    # Discard images that contain very small instances
    dataset.discard_small(min_area=0.03, max_area=1)

    # Define data sampler - Added for TPU purposes
    train_sampler = DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )

    # Define data loader - Modified for TPU purposes
    train_loader = DataLoader(
        dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers'],
    )

    # Define device - Added for TPU purposes
    device = xm.xla_device(devkind='TPU')

    # For evaluation, define fixed masks and noises
    data_iter = iter(train_loader)
    sample_batched = next(data_iter)
    x_fixed = sample_batched['image'][0:FLAGS['num_test_img']]
    x_fixed = Variable(x_fixed.to(device))
    y_fixed = sample_batched['single_fg_mask'][0:FLAGS['num_test_img']]
    y_fixed = Variable(y_fixed.to(device))
    z_fixed = torch.randn((FLAGS['num_test_img'], FLAGS['noise_size']))
    z_fixed = Variable(z_fixed.to(device))

    # Define networks
    generator = Generator_FG(
        z_dim=FLAGS['noise_size'],
        label_channel=len(category_names),
        num_res_blocks=FLAGS['num_res_blocks'],
    )
    discriminator_glob = Discriminator(channels=3 + len(category_names))
    discriminator_instance = Discriminator(
        channels=3 + len(category_names),
        input_size=FLAGS['local_patch_size'],
    )

    WRAPPED_GENERATOR = xmp.MpModelWrapper(generator)  # Added for TPU purposes
    WRAPPED_DISCRIMINATOR_GLOB = xmp.MpModelWrapper(discriminator_glob)  # Added for TPU purposes
    WRAPPED_DISCRIMINATOR_INSTANCE = xmp.MpModelWrapper(discriminator_instance)  # Added for TPU purposes
    G_fg = WRAPPED_GENERATOR.to(device)  # Modified for TPU purposes
    D_glob = WRAPPED_DISCRIMINATOR_GLOB.to(device)  # Modified for TPU purposes
    D_instance = WRAPPED_DISCRIMINATOR_INSTANCE.to(device)  # Modified for TPU purposes

    # Load parameters from pre-trained models
    if FLAGS['pre_trained_model_path'] is not None and FLAGS['pre_trained_model_epoch'] is not None:
        try:
            G_fg.load_state_dict(xser.load(
                FLAGS['pre_trained_model_path'] + 'G_fg_epoch_' + FLAGS['pre_trained_model_epoch']))
            D_glob.load_state_dict(xser.load(
                FLAGS['pre_trained_model_path'] + 'D_glob_epoch_' + FLAGS['pre_trained_model_epoch']))
            D_instance.load_state_dict(xser.load(
                FLAGS['pre_trained_model_path'] + 'D_local_epoch_' + FLAGS['pre_trained_model_epoch']))
            xm.master_print('Parameters are loaded!')
        except Exception:
            xm.master_print('Error: Pre-trained parameters are not loaded!')

    # Define interpolation operation
    up_instance = nn.Upsample(
        size=(FLAGS['local_patch_size'], FLAGS['local_patch_size']),
        mode='bilinear',
    )

    # Define pooling operation for the case that image size and local patch size are mismatched
    pooling_instance = nn.Sequential()
    if FLAGS['local_patch_size'] != FLAGS['img_size']:
        pooling_instance.add_module(
            '0',
            nn.AvgPool2d(int(FLAGS['img_size'] / FLAGS['local_patch_size'])),
        )

    # Define training loss function - binary cross entropy
    BCE_loss = nn.BCELoss()

    # Define feature matching loss
    criterionVGG = VGGLoss()
    criterionVGG = criterionVGG.to(device)  # Modified for TPU purposes

    # Define optimizers
    G_local_optimizer = optim.Adam(
        G_fg.parameters(),
        lr=FLAGS['lr'],
        betas=(0.0, 0.9),
    )
    D_local_optimizer = optim.Adam(
        list(filter(lambda p: p.requires_grad, D_glob.parameters())) +
        list(filter(lambda p: p.requires_grad, D_instance.parameters())),
        lr=FLAGS['lr'],
        betas=(0.0, 0.9),
    )

    # Define learning rate schedulers
    scheduler_G = lr_scheduler.StepLR(
        G_local_optimizer,
        step_size=FLAGS['optim_step_size'],
        gamma=FLAGS['optim_gamma'],
    )
    scheduler_D = lr_scheduler.StepLR(
        D_local_optimizer,
        step_size=FLAGS['optim_step_size'],
        gamma=FLAGS['optim_gamma'],
    )

    # ---------------------------- TRAIN -----------------------------------------
    xm.master_print('training start!')
    tracker = xm.RateTracker()  # Added for TPU reasons
    start_time = time.time()
    for epoch in range(FLAGS['train_epoch']):
        epoch_start_time = time.time()
        para_loader = pl.ParallelLoader(train_loader, [device])  # Added for TPU purposes
        loader = para_loader.per_device_loader(device)  # Added for TPU purposes

        D_local_losses = []
        G_local_losses = []

        y_real_ = torch.ones(FLAGS['batch_size'])
        y_fake_ = torch.zeros(FLAGS['batch_size'])
        y_real_ = Variable(y_real_.to(device))  # Modified for TPU purposes
        y_fake_ = Variable(y_fake_.to(device))  # Modified for TPU purposes

        data_iter = iter(loader)
        num_iter = 0
        while num_iter < len(loader):  # Modified for TPU purposes
            j = 0
            while j < FLAGS['critic_iter'] and num_iter < len(loader):
                j += 1
                sample_batched = next(data_iter)
                num_iter += 1

                x_ = sample_batched['image']
                x_ = Variable(x_.to(device))  # Modified for TPU purposes
                y_ = sample_batched['single_fg_mask']
                y_ = Variable(y_.to(device))  # Modified for TPU purposes
                fg_mask = sample_batched['seg_mask']
                fg_mask = Variable(fg_mask.to(device))  # Modified for TPU purposes
                y_instances = sample_batched['mask_instance']
                bbox = sample_batched['bbox']

                mini_batch = x_.size()[0]
                if mini_batch != FLAGS['batch_size']:
                    break

                # Update discriminators - D
                # Real examples
                D_glob.zero_grad()
                D_instance.zero_grad()

                y_reduced = torch.sum(y_, 1).clamp(0, 1).view(
                    y_.size(0), 1, FLAGS['img_size'], FLAGS['img_size'])
                x_d = torch.cat([x_, fg_mask], 1)

                x_instances = torch.zeros(
                    (FLAGS['batch_size'], 3, FLAGS['local_patch_size'], FLAGS['local_patch_size']))
                x_instances = Variable(x_instances.to(device))
                y_instances = Variable(y_instances.to(device))
                y_instances = pooling_instance(y_instances)
                G_instances = torch.zeros(
                    (FLAGS['batch_size'], 3, FLAGS['local_patch_size'], FLAGS['local_patch_size']))
                G_instances = Variable(G_instances.to(device))

                # Obtain instances
                for t in range(x_d.size()[0]):
                    x_instance = x_[t, 0:3, bbox[0][t]:bbox[1][t], bbox[2][t]:bbox[3][t]]
                    x_instance = x_instance.contiguous().view(
                        1, x_instance.size()[0], x_instance.size()[1], x_instance.size()[2])
                    x_instances[t] = up_instance(x_instance)

                D_result_instance = D_instance(torch.cat([x_instances, y_instances], 1)).squeeze()
                D_result = D_glob(x_d).squeeze()
                D_real_loss = BCE_loss(D_result, y_real_) + BCE_loss(D_result_instance, y_real_)
                D_real_loss.backward()

                # Fake examples
                z_ = torch.randn((mini_batch, FLAGS['noise_size']))
                z_ = Variable(z_.to(device))

                # Generate fake images
                G_fg_result = G_fg(z_, y_, torch.mul(x_, (1 - y_reduced)))
                G_result_d = torch.cat([G_fg_result, fg_mask], 1)

                # Obtain fake instances
                for t in range(x_d.size()[0]):
                    G_instance = G_result_d[t, 0:3, bbox[0][t]:bbox[1][t], bbox[2][t]:bbox[3][t]]
                    G_instance = G_instance.contiguous().view(
                        1, G_instance.size()[0], G_instance.size()[1], G_instance.size()[2])
                    G_instances[t] = up_instance(G_instance)

                D_result_instance = D_instance(
                    torch.cat([G_instances, y_instances], 1).detach()).squeeze()
                D_result = D_glob(G_result_d.detach()).squeeze()
                D_fake_loss = BCE_loss(D_result, y_fake_) + BCE_loss(D_result_instance, y_fake_)
                D_fake_loss.backward()
                xm.optimizer_step(D_local_optimizer)  # Modified for TPU purposes

                D_train_loss = D_real_loss + D_fake_loss
                D_local_losses.append(D_train_loss.item())

            if mini_batch != FLAGS['batch_size']:
                break

            # Update generator G
            G_fg.zero_grad()

            D_result = D_glob(G_result_d).squeeze()
            D_result_instance = D_instance(torch.cat([G_instances, y_instances], 1)).squeeze()
            G_train_loss = ((1 - FLAGS['trade_off_G']) * BCE_loss(D_result, y_real_)
                            + FLAGS['trade_off_G'] * BCE_loss(D_result_instance, y_real_))

            # Feature matching loss between generated image and corresponding ground truth
            FM_loss = criterionVGG(G_fg_result, x_)

            # Reconstruction loss
            Recon_loss = mse_loss(torch.mul(x_, (1 - y_reduced)),
                                  torch.mul(G_fg_result, (1 - y_reduced)))

            total_loss = G_train_loss + FLAGS['lambda_FM'] * FM_loss + FLAGS['lambda_recon'] * Recon_loss
            total_loss.backward()
            xm.optimizer_step(G_local_optimizer)
            G_local_losses.append(G_train_loss.item())

            xm.master_print('loss_d: %.3f, loss_g: %.3f'
                            % (D_train_loss.item(), G_train_loss.item()))
            if (num_iter % 100) == 0:
                xm.master_print('%d - %d complete!' % ((epoch + 1), num_iter))
                xm.master_print(result_folder_name)

        # Moved the scheduler steps here to avoid a warning
        scheduler_G.step()
        scheduler_D.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        xm.master_print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f'
                        % ((epoch + 1), FLAGS['train_epoch'], per_epoch_ptime,
                           torch.mean(torch.FloatTensor(D_local_losses)),
                           torch.mean(torch.FloatTensor(G_local_losses))))

        # Save images
        G_fg.eval()
        if epoch == 0:
            show_result(
                (epoch + 1), x_fixed, save=True,
                path=root + result_folder_name + '/' + model + str(epoch + 1) + '_gt.png')
            for t in range(y_fixed.size()[1]):
                show_result(
                    (epoch + 1), y_fixed[:, t:t + 1, :, :], save=True,
                    path=root + result_folder_name + '/' + model + str(epoch + 1) + '_' + str(t) + '_masked.png')
        show_result(
            (epoch + 1),
            G_fg(z_fixed, y_fixed,
                 torch.mul(x_fixed,
                           (1 - torch.sum(y_fixed, 1).view(
                               y_fixed.size(0), 1, FLAGS['img_size'], FLAGS['img_size'])))),
            save=True,
            path=root + result_folder_name + '/' + model + str(epoch + 1) + '_fg.png')
        G_fg.train()

        # Save model params
        if FLAGS['save_models'] and (epoch > 11 and epoch % 10 == 0):
            xser.save(
                G_fg.state_dict(),
                root + model_folder_name + '/' + model + 'G_fg_epoch_' + str(epoch) + '.pth',
                master_only=True)
            xser.save(
                D_glob.state_dict(),
                root + model_folder_name + '/' + model + 'D_glob_epoch_' + str(epoch) + '.pth',
                master_only=True)
            xser.save(
                D_instance.state_dict(),
                root + model_folder_name + '/' + model + 'D_local_epoch_' + str(epoch) + '.pth',
                master_only=True)

    end_time = time.time()
    total_ptime = end_time - start_time
    xm.master_print("Training finished!... save training results")
    xm.master_print('Training time: ' + str(total_ptime))
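
# A minimal sketch (assuming root, model_folder_name, model and FLAGS are in
# scope as above) of reading back a checkpoint written with xser.save:
# master_only=True writes a single copy from the master ordinal, and
# xser.load restores the state dict in any process.
import torch_xla.utils.serialization as xser

def load_generator(epoch, device):
    g = Generator_FG(z_dim=FLAGS['noise_size'],
                     label_channel=len(FLAGS['category_names'].split(',')),
                     num_res_blocks=FLAGS['num_res_blocks'])
    state = xser.load(root + model_folder_name + '/' + model
                      + 'G_fg_epoch_' + str(epoch) + '.pth')
    g.load_state_dict(state)
    return g.to(device)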
import os

# Environment variables are strings, so compare against '1' (not the int 1).
if os.environ.get('COLAB_GPU', '0') == '1':
    os.environ['GPU_NUM_DEVICES'] = '1'
    os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda/'

FLAGS = {}
FLAGS['batch_size'] = 4
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.002
FLAGS['num_epochs'] = 20
FLAGS['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False

SERIAL_EXEC = xmp.MpSerialExecutor()
net = EfficientNet.from_pretrained('efficientnet-b4', num_classes=config.num_classes)
WRAPPED_MODEL = xmp.MpModelWrapper(net)


def train_resnet18():  # name kept from the original; this trains the EfficientNet above
    torch.manual_seed(1)

    def get_dataset():
        paths_all, labels, cls2id = get_lists()
        train_dataset = MYDataset(paths_all['train'], labels['train'], config.train_transform)
        test_dataset = MYDataset(paths_all['val'], labels['val'], config.train_transform)
        return train_dataset, test_dataset

    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
def train_and_evaluate(
    genome: tuple,
    individual=None,
    args: argparse.Namespace = None,
    first_gen: bool = True,
    save: str = None,
    client_id: str = None,
):
    """
    Train and evaluate an individual using a TPU. Results are always saved in
    the save dir to make distributed data management easier.

    Args:
        genome: encoded architecture to decode and train
        individual: the individual being evaluated (must have an `id`)
        args: parsed command-line arguments
        first_gen: whether this is the first generation (no parents yet)
        save: name of the directory to save results into
        client_id: id of the distributed client, used to name weight blobs

    Returns:
        A dict with the evaluation results.
    """
    if args.stream == "tpu":
        # must warm up the TPU
        import torch_xla

    auxiliary = False
    assert hasattr(individual, "id")
    if not first_gen:
        # this is not the first generation, so mating should have occurred
        assert hasattr(individual, "parents")

    expr_root = ""

    save_pth = os.path.join(expr_root, "{}".format(save))
    utils.create_exp_dir(save_pth)

    CIFAR_CLASSES = 10
    learning_rate = 0.025
    momentum = 0.9
    weight_decay = 3e-4
    data_root = "../data"
    batch_size = args.batch_size
    auxiliary_weight = 0.4
    grad_clip = 5
    report_freq = 50
    train_params = {
        "auxiliary": auxiliary,
        "auxiliary_weight": auxiliary_weight,
        "grad_clip": grad_clip,
        "report_freq": report_freq,
    }

    if args.search_space == "micro":
        genotype = micro_encoding.decode(genome)
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers, auxiliary, genotype)
        if not first_gen:
            # change the way the weights are set up
            model = manage_weights(model, individual, expr_root, args)
    elif args.search_space == "macro":
        raise NotImplementedError("Not supported")
    else:
        raise NameError("Unknown search space type")

    logger.info("Architecture = %s", genotype)

    try:
        max_weight = args.max_weight
    except Exception:
        print("Could not determine maximum weight argument")
        max_weight = 1e20
    clip = weightClip(max_weight=max_weight, min_weight=max_weight * -1)

    if args.stream == "tpu":
        from projectcode.training.tpu import get_map_fn
        import torch_xla.distributed.xla_multiprocessing as xmp

        WRAPPED_MODEL = xmp.MpModelWrapper(model)
        logger.info("Executing TPU Training")
        map_fn = get_map_fn(model, train_params, data_root, momentum, weight_decay,
                            CIFAR_CLASSES, learning_rate, args.layers, batch_size,
                            epochs=args.epochs, save_pth=save_pth, args=args,
                            WRAPPED_MODEL=WRAPPED_MODEL, clip=clip)
        FLAGS = {}
        xmp.spawn(map_fn, args=(FLAGS,), nprocs=1, start_method="fork")
        valid_acc, n_flops = torch.load("results.pt")
    elif args.stream == "gpu":
        from projectcode.training.gpu import train_gpu

        logger.info("Executing GPU Training")
        valid_acc, n_flops = train_gpu(model, train_params, data_root, momentum,
                                       weight_decay, CIFAR_CLASSES, learning_rate,
                                       args.layers, batch_size, epochs=args.epochs,
                                       save_pth=save_pth, args=args, clip=clip)
    else:
        raise NameError("Unrecognized client stream")

    n_params = np.sum(
        np.prod(v.size())
        for v in filter(lambda p: p.requires_grad, model.parameters())) / 1e6

    if main_config.distributed_cloud and args.weight_init == "lammarckian":
        wt_path = f"{args.code}_{client_id}_weights_{individual.id:05d}.pt"
        torch.save(model.state_dict(), wt_path)
        blob_name = upload_blob(wt_path)
    else:
        blob_name = None

    torch.save(model.state_dict(), os.path.join(save_pth, "weights.pt"))

    result_dict = {
        "id": individual.id,
        "save_path": save_pth,
        "valid_acc": valid_acc,
        "params": n_params,
        "flops": n_flops,
        "wt_blob_name": blob_name,
    }
    dump(result_dict, os.path.join(save_pth, "result.pkl"))

    return result_dict
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")
    parser.add_argument("--model_name", default="ruGPT3Medium", type=str,
                        help="ruGPT3Small or ruGPT3Medium or ruGPT3Large")
    parser.add_argument("--with_negative", default=False, type=str,
                        help="version 2 with negative or not")

    # Other parameters
    parser.add_argument("--train_file", default=None, type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default=None, type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--max_seq_length", default=1000, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")
    args = parser.parse_args()
    print(args)

    os.environ['XLA_USE_BF16'] = '1'
    os.environ['TRIM_GRAPH_SIZE'] = '10000000'

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # logger.info("device: {} n_gpu: {}, distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError("If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError("If `do_predict` is True, then `predict_file` must be specified.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3large_based_on_gpt2")

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True,
            version_2_with_negative=args.with_negative)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps / xm.xrt_world_size()) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = GPT2ModelForQuestionAnswering.from_pretrained(
        cache_dir=PYTORCH_PRETRAINED_GPT2_CACHE,
        pretrained_model_name_or_path=args.model_name)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model = DDP(model)

    if args.do_train:
        train_features = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=True)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                   all_start_positions, all_end_positions)

        WRAPPED_MODEL = xmp.MpModelWrapper(model)

        def train_fn(index, train_data):
            device = xm.xla_device()
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            train_dataloader = DataLoader(
                train_data,
                sampler=train_sampler,
                batch_size=args.train_batch_size,
                num_workers=1,
                drop_last=True)
            global_step = 0
            losses = []
            w_model = WRAPPED_MODEL.to(device)

            # Prepare optimizer
            param_optimizer = list(w_model.named_parameters())

            # hack to remove pooler, which is not used
            # thus it produces None grads that break apex
            param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                 'weight_decay': 0.01},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0},
            ]

            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=args.learning_rate * xm.xrt_world_size())

            w_model.train()
            for epoch in range(int(args.num_train_epochs)):
                if xm.is_master_ordinal():
                    print(f"\n\nEpoch n.{epoch + 1}/{args.num_train_epochs} started")
                device_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(device)
                pbar = tqdm(device_loader, disable=not xm.is_master_ordinal(),
                            desc="Epoch progress", position=0, leave=True)
                for step, batch in enumerate(device_loader):
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_mask, segment_ids, start_positions, end_positions = batch
                    loss = w_model(input_ids, segment_ids, input_mask, start_positions, end_positions)
                    loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps
                    losses.append(loss.item())
                    # print(f"loss={loss.item()}")
                    optimizer.zero_grad()
                    loss.backward()
                    pbar.update(1)
                    if step % 10 == 0 and xm.is_master_ordinal() and step:
                        pbar.set_description(desc=f'loss : {np.mean(losses[10:])}')
                    # parenthesized so the `or` does not swallow the master-ordinal check
                    if xm.is_master_ordinal() and (((step + 1) % 1000 == 0) or step + 1 == len(device_loader)):
                        print(f"\navg loss for 1000 steps = {np.mean(losses)}")
                        losses = []
                    xm.optimizer_step(optimizer)
                    global_step += 1

            xm.rendezvous('save_model')
            w_model.to("cpu")
            if xm.is_master_ordinal():
                print("\nSaving model...")
                print(w_model == model)
                # torch_xla.utils.serialization.save(model, args.output_dir, master_only=True, global_master=True)
                # Save a trained model, configuration and tokenizer
                model_to_save = w_model.module if hasattr(w_model, 'module') else w_model  # Only save the model itself

                # If we save using the predefined names, we can load using `from_pretrained`
                output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                # xm.save(model, output_model_file, master_only=True, global_master=False)
                torch.save(model_to_save.state_dict(), output_model_file)
                model_to_save.config.to_json_file(output_config_file)
                tokenizer.save_vocabulary(args.output_dir)

        # train_fn(device_loader)
        xmp.spawn(train_fn, args=(train_data,), nprocs=1, start_method='fork')

    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Load a trained model and vocabulary that you have fine-tuned
        model = GPT2ModelForQuestionAnswering.from_pretrained(args.output_dir)
        tokenizer = GPT2Tokenizer.from_pretrained(args.output_dir)
    else:
        model = GPT2ModelForQuestionAnswering.from_pretrained(
            cache_dir=PYTORCH_PRETRAINED_GPT2_CACHE,
            pretrained_model_name_or_path=args.model_name)
    device = xm.xla_device()
    model.to(device)

    if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = read_squad_examples(
            input_file=args.predict_file, is_training=False)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(
                eval_dataloader, desc="Evaluating",
                disable=args.local_rank not in [-1, 0], position=0, leave=True):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))

        output_prediction_file = os.path.join(args.output_dir, "predictions.json")
        output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file,
                          args.verbose_logging, True, args.null_score_diff_threshold)
def main(args):
    if args.device == 'tpu':
        args.distributed = False
    else:
        utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    if args.device == 'tpu':
        device = xm.xla_device()
    else:
        device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        if args.device == 'tpu':
            model = xmp.MpModelWrapper(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers)

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
        from engine import tpu_evaluate
        tpu_evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir)
        return

    if args.eval:
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args.output_dir)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch, args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args.output_dir)

        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters,
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
FLAGS["learning_rate"] = 4e-4 FLAGS["num_epochs"] = 40 FLAGS["weight_decay"] = 1e-4 FLAGS["log_steps"] = 20 FLAGS["img_size"] = IMG_SIZE FLAGS["loss"] = "focal" FLAGS["optimizer"] = "AdamW" FLAGS["scheduler"] = "ReduceLROnPlateau" FLAGS["exp_name"] = "enet_b0" FLAGS["fold"] = [0] # , 1, 2, 3, 4] FLAGS["val_freq"] = 1 FLAGS["num_cores"] = 8 FLAGS["seed"] = 42 model_cpu = EfficientNet("tf_efficientnet_b0_ns") WRAPPED_MODEL = xmp.MpModelWrapper(model_cpu) SERIAL_EXEC = xmp.MpSerialExecutor() for fold_no in FLAGS["fold"]: X_train = df_train[df_train["fold"] != fold_no][[ col for col in df_train.columns if col != "target" ]] X_val = df_train[df_train["fold"] == fold_no][[ col for col in df_train.columns if col != "target" ]] y_train = df_train[df_train["fold"] != fold_no][[ col for col in df_train.columns if col == "target" ]] y_val = df_train[df_train["fold"] == fold_no][[ col for col in df_train.columns if col == "target" ]]
transform = transforms.Compose([
    transforms.Resize((FLAGS['img_size'], FLAGS['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Networks
generator = Generator_BG(z_dim=FLAGS['noise_size'],
                         label_channel=len(category_names),
                         num_res_blocks=FLAGS['num_res_blocks'],
                         num_res_blocks_fg=FLAGS['num_res_blocks_fg'],
                         num_res_blocks_bg=FLAGS['num_res_blocks_bg'])
discriminator_glob = Discriminator(channels=3 + len(category_names))

WRAPPED_GENERATOR = xmp.MpModelWrapper(generator)  # Added for TPU purposes
WRAPPED_DISCRIMINATOR = xmp.MpModelWrapper(discriminator_glob)  # Added for TPU purposes


def main(rank):  # Modified for TPU purposes
    # Seed - Added for TPU purposes
    torch.manual_seed(1)

    # Define dataset - Modified for TPU purposes
    dataset = SERIAL_EXEC.run(
        lambda: CocoData(root=FLAGS['train_imgs_path'],
                         annFile=FLAGS['train_annotation_path'],
                         category_names=category_names,
                         transform=transform,
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return F.log_softmax(out, dim=1)


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ResNet18())


def train_resnet18():
    torch.manual_seed(1)

    def get_dataset():
        norm = transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            norm,
        ])
        transform_test = transforms.Compose([
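
# A minimal sketch (assumption: this continues train_resnet18 above with
# hypothetical loader/optimizer variables) of the per-process training step
# that pairs with SERIAL_EXEC and WRAPPED_MODEL: each process grabs the shared
# weights with .to(device), wraps its DataLoader in a ParallelLoader, and
# steps through xm.optimizer_step, which also all-reduces gradients across cores.
def train_loop_fn(loader, model, optimizer):
    tracker = xm.RateTracker()
    model.train()
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)  # matches the log_softmax output above
        loss.backward()
        xm.optimizer_step(optimizer)  # gradient all-reduce + step on TPU
        tracker.add(data.size(0))

# Inside each spawned process (sketch):
#   device = xm.xla_device()
#   model = WRAPPED_MODEL.to(device)
#   para_loader = pl.ParallelLoader(train_loader, [device])
#   train_loop_fn(para_loader.per_device_loader(device), model, optimizer)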
def map_fn(index):
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = add_custom_args(
        HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, args = parser.parse_args_into_dataclasses()
    logger.info("parser built")

    # load and instantiate tokenizer
    global tokenizer
    tokenizer = BertTokenizerFast.from_pretrained(
        Path(args.icebert_folder) / (str(data_args.max_seq_length) + "_tokenizers")
        / (model_args.model_type + "_tokenizer"))

    # load and instantiate configuration file
    with open(args.config_file, 'r') as fp:
        config_dict = json.load(fp)
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        max_position_embeddings=data_args.max_seq_length,
        hidden_size=config_dict["hidden_size"],
        num_hidden_layers=config_dict["num_hidden_layers"],
        num_attention_heads=config_dict["num_attention_heads"],
        intermediate_size=config_dict["intermediate_size"],
        hidden_act=config_dict["hidden_act"],
        hidden_dropout_prob=config_dict["hidden_dropout_prob"],
        attention_probs_dropout_prob=config_dict["attention_probs_dropout_prob"],
        type_vocab_size=config_dict["type_vocab_size"],
        initializer_range=config_dict["initializer_range"],
        layer_norm_eps=config_dict["layer_norm_eps"],
        **config_kwargs)

    # load and instantiate model
    # IMPORTANT: the model is wrapped using xmp.MpModelWrapper, which loads the model
    # only once, in the global scope
    model = xmp.MpModelWrapper(BertForMaskedLM(config))
    logger.info("tokenizer and model instantiated")

    # move model to device; MpModelWrapper.to() returns the underlying module,
    # so keep the result for the Trainer below
    device = xm.xla_device()
    model = model.to(device)
    xm.rendezvous("Model moved to device")

    # prepare dataset and data collator for on-the-fly tokenization and masking
    global data_files
    data_files = {"train": data_args.train_file}
    global max_len
    max_len = data_args.max_seq_length
    global cache_dir
    cache_dir = model_args.cache_dir
    tokenized_datasets = SERIAL_EXEC.run(get_tokenized_dataset)
    xm.rendezvous("Tokenized dataset loaded")
    data_collator = SERIAL_EXEC.run(get_data_collator)
    xm.rendezvous("DataCollator loaded")

    # handle possible checkpoints
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # select and optionally sample the train_dataset
    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    # set up the trainer (pass the possibly sampled train_dataset, not the raw split)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # start training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        logger.info("*** Starting training ***")
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload
        logger.info("*** Model saved ***")
        try:
            metrics = train_result.metrics
            logger.info("*** metrics assigned from train_result ***")
            max_train_samples = (
                data_args.max_train_samples
                if data_args.max_train_samples is not None
                else len(train_dataset)
            )
            logger.info("*** max train samples assigned ***")
            metrics["train_samples"] = min(max_train_samples, len(train_dataset))
            logger.info("*** metrics[train_samples] assigned ***")
            trainer.log_metrics("train", metrics)
            logger.info("*** trainer.log_metrics called ***")
            trainer.save_metrics("train", metrics)
            logger.info("*** trainer.save_metrics called ***")
            trainer.save_state()
            logger.info("*** trainer.save_state called: last line in the map_fn function! ***")
        except Exception:
            logger.warning("*** Failed to save metrics and trainer state; see the following exception: ***")
            traceback.print_exc()
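
# A minimal sketch (assumption: this mirrors the usual launch pattern for the
# map_fn above; it is not from the original script) of defining the serial
# executor the function relies on and spawning one process per TPU core.
SERIAL_EXEC = xmp.MpSerialExecutor()

if __name__ == "__main__":
    xmp.spawn(map_fn, args=(), nprocs=8, start_method="fork")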