def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    train_sampler = DistributedSampler(train_dataset, num_replicas=bps.size(), rank=bps.rank())
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = SGD(optimizer_grouped_parameters, lr=args.learning_rate, momentum=0.9)
    optimizer = bps.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
    bps.broadcast_parameters(model.state_dict(), root_rank=0)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (bps.size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
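# --- Illustrative sketch (not part of the original source) -------------------
# The train() above assumes the usual BytePS bootstrap has already run in the
# caller (bps.init(), GPU pinning, rank-0 gating). A minimal, hypothetical
# version of that setup is sketched here; the name setup_byteps() and the
# return values are illustrative, not taken from this repository.
import torch
import byteps.torch as bps


def setup_byteps():
    bps.init()  # start the BytePS communication backend for this worker
    if torch.cuda.is_available():
        # Pin each local worker process to its own GPU before building the model.
        torch.cuda.set_device(bps.local_rank())
        device = torch.device("cuda", bps.local_rank())
    else:
        device = torch.device("cpu")
    # Rank 0 is the only worker that should write TensorBoard summaries and
    # checkpoints, which is what the `args.local_rank in [-1, 0]` checks gate.
    return device, bps.rank(), bps.size()
# ------------------------------------------------------------------------------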
optimizer = bps.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    compression=compression,
    backward_passes_per_step=args.batches_per_allreduce)

# Restore from a previous checkpoint, if initial_epoch is specified.
# BytePS: restore on the first worker which will broadcast weights to other workers.
if resume_from_epoch > 0 and bps.rank() == 0:
    filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

# BytePS: broadcast parameters & optimizer state.
bps.broadcast_parameters(model.state_dict(), root_rank=0)
bps.broadcast_optimizer_state(optimizer, root_rank=0)


def train(epoch):
    model.train()
    train_sampler.set_epoch(epoch)
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')

    with tqdm(total=len(train_loader),
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        for batch_idx, (data, target) in enumerate(train_loader):
            adjust_learning_rate(epoch, batch_idx)
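# --- Illustrative sketch (not part of the original source) -------------------
# The loop above accumulates into Metric('train_loss') / Metric('train_accuracy'),
# which are not defined in this snippet. The sketch below shows what such a
# helper typically looks like in BytePS/Horovod example code: each update is
# averaged across workers. The call to bps.push_pull is an assumption about the
# byteps.torch API; treat this as a sketch, not the repository's actual helper.
import torch
import byteps.torch as bps


class Metric:
    """Running average of a scalar tensor, averaged across all BytePS workers."""

    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.0)
        self.n = torch.tensor(0.0)

    def update(self, val):
        # push_pull is assumed to average the value across workers (like allreduce).
        self.sum += bps.push_pull(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        return self.sum / self.n
# ------------------------------------------------------------------------------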
def build_model(self):
    """ DataLoader """
    if self.fix_aug:
        print("FIX AUG ON")
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.img_size + 30, self.img_size + 30)),
            transforms.RandomCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
    test_transform = transforms.Compose([
        transforms.Resize((self.img_size, self.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    self.trainA = ImageFolder(os.path.join(self.dataset_dir, self.dataset, 'trainA'),
                              train_transform, list_mode=self.list_mode)
    self.trainB = ImageFolder(os.path.join(self.dataset_dir, self.dataset, 'trainB'),
                              train_transform, list_mode=self.list_mode)
    self.testA = ImageFolder(os.path.join(self.dataset_dir, self.dataset, 'testA'),
                             test_transform, list_mode=self.list_mode)
    self.testB = ImageFolder(os.path.join(self.dataset_dir, self.dataset, 'testB'),
                             test_transform, list_mode=self.list_mode)

    trainA_sampler = torch.utils.data.distributed.DistributedSampler(
        self.trainA, num_replicas=bps.size(), rank=bps.rank())
    trainB_sampler = torch.utils.data.distributed.DistributedSampler(
        self.trainB, num_replicas=bps.size(), rank=bps.rank())
    testA_sampler = torch.utils.data.distributed.DistributedSampler(
        self.testA, num_replicas=bps.size(), rank=bps.rank())
    testB_sampler = torch.utils.data.distributed.DistributedSampler(
        self.testB, num_replicas=bps.size(), rank=bps.rank())

    self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size,
                                    sampler=trainA_sampler, num_workers=1)
    self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size,
                                    sampler=trainB_sampler, num_workers=1)
    self.testA_loader = DataLoader(self.testA, batch_size=1, sampler=testA_sampler)
    self.testB_loader = DataLoader(self.testB, batch_size=1, sampler=testB_sampler)

    """ Define Generator, Discriminator """
    self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res,
                                  img_size=self.img_size, light=self.light).to(self.device)
    self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res,
                                  img_size=self.img_size, light=self.light).to(self.device)
    self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
    self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
    self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
    self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)

    """ Define Loss """
    self.L1_loss = nn.L1Loss().to(self.device)
    self.MSE_loss = nn.MSELoss().to(self.device)
    self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

    gen_named_parameters = []
    dis_named_parameters = []
    for n, p in (list(self.genA2B.named_parameters(prefix='genA2B')) +
                 list(self.genB2A.named_parameters(prefix='genB2A'))):
        gen_named_parameters.append((n, p))
    for n, p in (list(self.disGA.named_parameters(prefix='disGA')) +
                 list(self.disGB.named_parameters(prefix='disGB')) +
                 list(self.disLA.named_parameters(prefix='disLA')) +
                 list(self.disLB.named_parameters(prefix='disLB'))):
        dis_named_parameters.append((n, p))

    gen_state_dict = OrderedDict(
        [("genA2B." + k, v) for k, v in self.genA2B.state_dict().items()] +
        [("genB2A." + k, v) for k, v in self.genB2A.state_dict().items()])
    dis_state_dict = OrderedDict(
        [("disGA." + k, v) for k, v in self.disGA.state_dict().items()] +
        [("disGB." + k, v) for k, v in self.disGB.state_dict().items()] +
        [("disLA." + k, v) for k, v in self.disLA.state_dict().items()] +
        [("disLB." + k, v) for k, v in self.disLB.state_dict().items()])

    bps.broadcast_parameters(gen_state_dict, root_rank=0)
    bps.broadcast_parameters(dis_state_dict, root_rank=0)

    """ Trainer """
    self.G_optim = torch.optim.Adam(
        itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()),
        lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
    self.D_optim = torch.optim.Adam(
        itertools.chain(self.disGA.parameters(), self.disGB.parameters(),
                        self.disLA.parameters(), self.disLB.parameters()),
        lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)

    named_parameters = []
    for n, p in list(self.genA2B.named_parameters()):
        named_parameters.append(("genA2B." + n, p))
    for n, p in list(self.genB2A.named_parameters()):
        named_parameters.append(("genB2A." + n, p))

    self.G_optim = bps.DistributedOptimizer(
        self.G_optim, named_parameters=gen_named_parameters,
        compression=bps.Compression.none)
    self.D_optim = bps.DistributedOptimizer(
        self.D_optim, named_parameters=dis_named_parameters,
        compression=bps.Compression.none)
    self.G_optim._handles.clear()
    self.D_optim._handles.clear()

    """ Define Rho clipper to constrain the value of rho in AdaILN and ILN """
    self.Rho_clipper = RhoClipper(0, 1)
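# --- Illustrative sketch (not part of the original source) -------------------
# RhoClipper is imported from the UGATIT utilities; this is a sketch of what it
# conventionally does. Applied via `net.apply(self.Rho_clipper)` after each
# generator/discriminator step, it clamps the learnable `rho` of AdaILN/ILN
# layers into [0, 1]. Treat the exact implementation below as an assumption.
class RhoClipper:
    def __init__(self, clip_min, clip_max):
        assert clip_min < clip_max
        self.clip_min = clip_min
        self.clip_max = clip_max

    def __call__(self, module):
        # Only AdaILN / ILN modules expose a learnable `rho` parameter.
        if hasattr(module, 'rho'):
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
# ------------------------------------------------------------------------------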