def train(args, train_dataset, model, tokenizer):
    """ Train the model.

    Runs the standard transformers-style training loop: RAdam optimizer with a
    linear warmup/decay schedule, optional apex fp16, and the project's
    DataParallelModel / DataParallelCriterion wrappers for multi-GPU runs.
    Returns ``(global_step, average training loss per step)``.
    """
    # TensorBoard writer is only created on the main process.
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    # Effective batch size scales with the number of GPUs (DataParallel splits it).
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # Either run a fixed number of optimizer steps (max_steps) or derive the
    # total number of steps from the epoch count.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
    # Warm up over 1% of all optimization steps.
    args.warmup_steps = t_total // 100

    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer_grouped_parameters = get_param_groups(args, model)
    optimizer = RAdam(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    # Log and checkpoint once per epoch worth of batches.
    args.logging_steps = len(train_dataloader) // 1
    args.save_steps = args.logging_steps

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)
    for _ in train_iterator:
        args.current_epoch = _
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'token_type_ids': batch[2]
                if args.model_type in ['bert', 'xlnet'] else None,
            }  # XLM and RoBERTa don't use segment_ids
            # 'labels': batch[3]}
            outputs = model(**inputs)
            # DataParallelModel returns one output tuple per device; keep only
            # the first element (the logits) of each.
            outputs = [outputs[i][0] for i in range(len(outputs))]
            # NOTE(review): the criterion is rebuilt on every batch; it looks
            # loop-invariant and could probably be hoisted — confirm
            # DataParallelCriterion is stateless before moving it.
            loss_fct = CrossEntropyLoss()
            loss_fct = DataParallelCriterion(loss_fct)
            loss = loss_fct(outputs, batch[3])

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            # Only step the optimizer every `gradient_accumulation_steps` batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    # NOTE(review): tb_writer only exists when args.local_rank
                    # is in [-1, 0]; this block would raise NameError on other
                    # ranks — confirm this script is never launched distributed.
                    if args.local_rank == -1 and args.evaluate_during_training:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
class LMTrainer:
    """Masked-language-model trainer.

    Masks tokens with probability ``mask_prob``, optimizes an NLL loss over
    the masked positions, and wraps the model/criterion in DataParallelModel /
    DataParallelCriterion when more than one GPU is visible.
    """

    def __init__(self, model, mask_prob: float = 0.15, clip: int = 1,
                 optimizer=None):
        """
        :param model: language model exposing ``text_processor``.
        :param mask_prob: probability of masking each token.
        :param clip: max gradient norm used for clipping.
        :param optimizer: optimizer to step; when None no parameter update is done.
        """
        self.model = model
        self.clip = clip
        self.optimizer = optimizer
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.mask_prob = mask_prob
        # Padding positions never contribute to the loss.
        self.criterion = nn.NLLLoss(
            ignore_index=model.text_processor.pad_token_id())

        num_gpu = torch.cuda.device_count()
        if num_gpu > 1:
            print("Let's use", num_gpu, "GPUs!")
            self.model = DataParallelModel(self.model)
            self.criterion = DataParallelCriterion(self.criterion)

        self.best_dev_loss = float("inf")
        self.best_train_loss = float("inf")
        self.last_train_loss = float("inf")

    def train_epoch(self, data_iter: data_utils.DataLoader,
                    dev_data_iter: data_utils.DataLoader, saving_path: str,
                    step: int):
        "Standard Training and Logging Function"
        start = time.time()
        # Epoch-wide totals plus a rolling window reset every 50 steps.
        # (Fixed a redundant second `cur_loss = 0` assignment.)
        total_tokens, total_loss, tokens, cur_loss = 0, 0, 0, 0
        # Unwrap DataParallelModel to reach text_processor / save().
        model = self.model.module if hasattr(self.model,
                                             "module") else self.model

        for i, batch in enumerate(data_iter):
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            mask, target, texts = mask_text(self.mask_prob, batch["pad_mask"],
                                            batch["texts"],
                                            model.text_processor)
            try:
                predictions = self.model(mask=mask,
                                         texts=texts,
                                         pads=batch["pad_mask"],
                                         langs=batch["langs"])
                ntokens = target.size(0)
                if ntokens == 0:  # Nothing to predict!
                    continue

                loss = self.criterion(predictions, target).mean()
                loss.backward()
                # Restore masked tokens so the (cached) batch is unchanged.
                unmask_text(mask, target, texts)

                if self.optimizer is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.clip)
                    self.optimizer.step()
                    step += 1

                loss = float(loss.data) * ntokens
                total_loss += loss
                cur_loss += loss
                total_tokens += ntokens
                tokens += ntokens

                if step % 50 == 0:
                    elapsed = time.time() - start
                    print(
                        datetime.datetime.now(),
                        "Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                        (step, cur_loss / tokens, tokens / elapsed))
                    if step % 500 == 0:
                        self.validate_and_save(saving_path, dev_data_iter)
                    start, tokens, cur_loss = time.time(), 0, 0
            except RuntimeError as err:
                # Best-effort recovery (e.g. CUDA OOM): report and skip the batch.
                print("Problem with batch item", texts.size())
                torch.cuda.empty_cache()

        current_loss = total_loss / total_tokens
        print("Total loss in this epoch: %f" % current_loss)
        if current_loss < self.best_train_loss:
            # New best training loss: checkpoint model and (pickled) optimizer.
            self.best_train_loss = current_loss
            model_to_save = (self.model.module if hasattr(
                self.model, "module") else self.model)
            model_to_save.save(saving_path + ".latest")
            with open(os.path.join(saving_path + ".latest", "optim"),
                      "wb") as fp:
                pickle.dump(self.optimizer, fp)
        self.last_train_loss = current_loss

        self.validate_and_save(saving_path, dev_data_iter)
        return step

    def validate_and_save(self, saving_path, dev_data_iter):
        """Compute dev loss; save model + optimizer whenever the best dev loss improves."""
        with torch.no_grad():
            model = self.model.module if hasattr(self.model,
                                                 "module") else self.model
            model.eval()
            total_dev_loss, total_dev_tokens = 0, 0
            for batch in dev_data_iter:
                # Clone texts so masking never mutates the cached dev batch.
                mask, target, texts = mask_text(self.mask_prob,
                                                batch["pad_mask"],
                                                batch["texts"].clone(),
                                                model.text_processor)
                predictions = self.model(mask=mask,
                                         texts=texts,
                                         pads=batch["pad_mask"],
                                         langs=batch["langs"])
                ntokens = target.size(0)
                if ntokens == 0:  # Nothing to predict!
                    continue
                loss = self.criterion(predictions,
                                      target).mean().data * ntokens
                total_dev_loss += float(loss)
                total_dev_tokens += ntokens

            dev_loss = total_dev_loss / total_dev_tokens
            print("Current dev loss", dev_loss)
            if self.best_dev_loss > float(dev_loss):
                self.best_dev_loss = float(dev_loss)
                print("saving best dev loss", self.best_dev_loss)
                model_to_save = (self.model.module if hasattr(
                    self.model, "module") else self.model)
                model_to_save.save(saving_path)
                with open(os.path.join(saving_path, "optim"), "wb") as fp:
                    pickle.dump(self.optimizer, fp)
            model.train()

    @staticmethod
    def config_dropout(model, dropout):
        """Set both hidden and attention dropout on the model's encoder config."""
        model.encoder.config.hidden_dropout_prob = dropout
        model.encoder.config.attention_probs_dropout_prob = dropout

    @staticmethod
    def train(options):
        """Entry point: build the LM, datasets, loaders and optimizer, then
        train epoch after epoch until ``options.step`` optimizer steps."""
        if not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        lm_class = ReformerLM if options.reformer else LM
        if options.pretrained_path is None:
            lm = lm_class(text_processor=text_processor,
                          size=options.model_size)
        else:
            lm = lm_class.load(options.pretrained_path)

        if options.reformer:
            lm.config.hidden_dropout_prob = options.dropout
            lm.config.local_attention_probs_dropout_prob = options.dropout
            lm.config.lsh_attention_probs_dropout_prob = options.dropout
        else:
            LMTrainer.config_dropout(lm, options.dropout)

        train_data = dataset.TextDataset(save_cache_dir=options.train_path,
                                         max_cache_size=options.cache_size)
        dev_data = dataset.TextDataset(save_cache_dir=options.dev_path,
                                       max_cache_size=options.cache_size,
                                       load_all=True)

        if options.continue_train:
            # Resume with the previously pickled optimizer state.
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(lm, options.learning_rate,
                                        options.warmup)

        trainer = LMTrainer(model=lm,
                            mask_prob=options.mask_prob,
                            optimizer=optimizer,
                            clip=options.clip)

        collator = dataset.TextCollator(pad_idx=text_processor.pad_token_id())
        train_sampler, dev_sampler = None, None
        pin_memory = torch.cuda.is_available()
        loader = data_utils.DataLoader(train_data,
                                       batch_size=options.batch,
                                       shuffle=False,
                                       pin_memory=pin_memory,
                                       collate_fn=collator,
                                       sampler=train_sampler)
        dev_loader = data_utils.DataLoader(dev_data,
                                           batch_size=options.batch,
                                           shuffle=False,
                                           pin_memory=pin_memory,
                                           collate_fn=collator,
                                           sampler=dev_sampler)

        step, train_epoch = 0, 1
        while step <= options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(data_iter=loader,
                                       dev_data_iter=dev_loader,
                                       saving_path=options.model_path,
                                       step=step)
            # BUG FIX: the epoch counter was never advanced, so every epoch
            # printed "train epoch 1".
            train_epoch += 1
class BERTTrainer:
    """
    BERTTrainer make the pretrained BERT model with two LM training method.

    1. Masked Language Model : 3.3.1 Task #1: Masked LM
    2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction
    """

    def __init__(self, model, vocab_size, train_dataloader,
                 test_dataloader=None, lr: float = 1e-4, betas=(0.9, 0.999),
                 weight_decay: float = 0.01, warmup_steps=10000,
                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10,
                 include_next=False, include_vision=True, total_epochs=1):
        """
        :param model: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param warmup_steps: warmup steps for the LR schedule
        :param with_cuda: traning with cuda
        :param log_freq: logging frequency of the batch iteration
        :param include_next: also train the next/vision prediction head
        :param include_vision: feed visual features through the model
        :param total_epochs: total epochs, used only for progress reporting
        """
        # Setup cuda device for BERT training, argument -c, --cuda should be true
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        n_gpu = torch.cuda.device_count()
        # BUG FIX: `device` was an undefined bare name (NameError); report the
        # attribute that was just set.
        print("device", self.device, "n_gpu", n_gpu)

        # Initialize the BERT Language Model, with BERT model
        self.model = model.to(self.device)
        self.bert = self.model.bert
        self.padding_idx = 0
        self.include_next = include_next
        self.include_vision = include_vision

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = DataParallelModel(
                self.model, device_ids=range(torch.cuda.device_count()))

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = optim.Adamax(self.model.parameters(), lr=lr, betas=betas,
                                  weight_decay=weight_decay)
        # When wrapped for data-parallel, the BERT module lives on `.module`.
        if self.model.__class__.__name__ in [
                'DataParallel', 'DataParallelModel'
        ]:
            self.optim_schedule = ScheduledOptim(
                self.optim, self.model.module.bert.transformer_hidden_size,
                n_warmup_steps=warmup_steps)
        else:
            self.optim_schedule = ScheduledOptim(
                self.optim, self.model.bert.transformer_hidden_size,
                n_warmup_steps=warmup_steps)

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = nn.NLLLoss(ignore_index=0)
        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.criterion = DataParallelCriterion(
                self.criterion, device_ids=range(torch.cuda.device_count()))

        self.log_freq = log_freq
        self.total_iters = total_epochs * len(train_dataloader)
        print("Total Parameters:",
              sum(p.nelement() for p in self.model.parameters()))

    def train(self, epoch):
        """Run one training epoch over the train data loader."""
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        """Run one evaluation pass (no backward / optimizer step)."""
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
        and also auto save the model every peoch

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}",
                              disable=True)

        avg_loss = 0.0
        total_correct = 0
        total_element = 0
        for i, data in data_iter:
            # 0. prepare the text sequence tensor
            #data = {key: value.to(self.device) for key, value in data.items()}
            seq_tensor = data['masked_text_seq']
            labels = data['masked_text_label']
            # First padding position == sequence length; 0 means no padding found.
            seq_lengths = np.argmax(seq_tensor == self.padding_idx, axis=1)
            seq_lengths[seq_lengths == 0] = seq_tensor.shape[1]  # Full length

            # Sort sequences by lengths
            seq_lengths, perm_idx = seq_lengths.sort(0, True)
            sorted_tensor = seq_tensor[perm_idx]
            # BUG FIX: `padding_idx` was an undefined bare name (NameError);
            # use the attribute initialized in __init__.
            mask = (sorted_tensor == self.padding_idx)[:, :seq_lengths[0]]
            f_t_all = data['feature_all']
            isnext = data["isnext"]
            f_t_all = f_t_all[perm_idx]
            isnext = isnext[perm_idx]
            labels = labels[perm_idx]

            # 1. forward the next_sentence_prediction and masked_lm model
            # NOTE(review): with DataParallelModel, `output` is one
            # (next_sent, mask_lm) pair per GPU and zip(*output) regroups
            # them into tuples; the criterion below relies on
            # DataParallelCriterion accepting those tuples — confirm.
            if self.include_vision:
                output = self.model.forward(sorted_tensor.cuda(), mask.cuda(),
                                            seq_lengths.cuda(), f_t_all.cuda())
                length_output = len(output)
                print("You got %d outputs" % (length_output))
                next_sent_output, mask_lm_output = zip(*output)
                # BUG FIX: `%d` cannot format a torch.Size; use %s and wrap in
                # a 1-tuple so a multi-element Size is not unpacked by `%`.
                print("vision test shape is %s " %
                      (next_sent_output[1].shape, ))
                print("lm test shape is %s " % (mask_lm_output[1].shape, ))
            else:
                output = self.model.forward(sorted_tensor.cuda(), mask.cuda(),
                                            seq_lengths.cuda(), None)
                length_output = len(output)
                print("You got %d outputs" % (length_output))
                next_sent_output, mask_lm_output = zip(*output)

            # 2-1. NLL(negative log likelihood) loss of is_next classification result
            next_loss = 0
            if self.include_vision and self.include_next:
                next_loss = self.criterion(next_sent_output, isnext.cuda())

            # 2-2. NLLLoss of predicting masked token word
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2),
                                       labels[:, :seq_lengths[0]].cuda())

            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
            # BUG FIX: this sum was commented out, leaving `loss` undefined at
            # the `loss.mean()` call below (NameError).
            loss = next_loss + mask_loss

            # 3. backward and optimization only in train
            loss = loss.mean()
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next vision prediction accuracy
            if self.include_next:
                correct = next_sent_output.argmax(dim=-1).eq(
                    isnext.cuda()).sum().item()
                total_correct += correct
                total_element += data["isnext"].nelement()
            avg_loss += loss.item()

            if self.include_next:
                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss / (i + 1),
                    "avg_acc": total_correct / total_element * 100,
                    "loss": loss.item()
                }
            else:
                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss / (i + 1),
                    "loss": loss.item()
                }

            #if i % self.log_freq == 0:
            #    data_iter.write(str(post_fix))
            if i % 100 == 0:
                #print("PROGRESS: {}%".format(round((myidx) * 100 / n_iters, 4)))
                print("\n")
                print("PROGRESS: {}%".format(
                    round((epoch * len(data_loader) + i) * 100 /
                          self.total_iters, 4)))
                print("EVALERR: {}%".format(avg_loss / (i + 1)))

        #print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter))

    def save(self, epoch, file_path="pretrained_models/addbert_trained.model"):
        """
        Saving the current BERT model on file_path

        :param epoch: current epoch number
        :param file_path: model output path which gonna be file_path+"ep%d" % epoch
        :return: final_output_path
        """
        output_path = file_path + ".ep%d" % epoch
        # Move to CPU for a device-agnostic checkpoint, then restore the device.
        torch.save(self.bert.cpu(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
class Trainer:
    """ trainer class

    Fine-tunes BertForSequenceClassification (2 labels) on a comment dataset,
    with class-imbalance weighting, early stopping on F-score, and optional
    DataParallelModel/DataParallelCriterion multi-GPU wrapping.
    """
    def __init__(self, cfg: Namespace, data: Dataset):
        """
        Args:
            cfg:  configuration
            data:  train dataset
        """
        self.cfg = cfg
        # 80/20 train/validation split.
        self.train, self.valid = data.split(0.8)
        RATING_FIELD.build_vocab(self.train)

        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')  # pylint: disable=no-member
        self.batch_size = cfg.batch_size
        if torch.cuda.is_available():
            # Scale the batch across all visible GPUs (DataParallel splits it).
            self.batch_size *= torch.cuda.device_count()

        # Bucket by descending comment length to minimize padding per batch.
        self.trn_itr = BucketIterator(
            self.train,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            train=True,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        self.vld_itr = BucketIterator(
            self.valid,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=False,
            train=False,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        # Log less often when the validation set is large.
        self.log_step = 1000
        if len(self.vld_itr) < 100:
            self.log_step = 10
        elif len(self.vld_itr) < 1000:
            self.log_step = 100

        bert_path = cfg.bert_path if cfg.bert_path else 'bert-base-cased'
        self.model = BertForSequenceClassification.from_pretrained(
            bert_path, num_labels=2)
        # Weight the positive class by the negative/positive count ratio to
        # counter class imbalance.
        pos_weight = (
            len([exam for exam in self.train.examples if exam.target < 0.5]) /
            len([exam for exam in self.train.examples if exam.target >= 0.5]))
        pos_wgt_tensor = torch.tensor([1.0, pos_weight],
                                      device=self.device)  # pylint: disable=not-callable
        self.criterion = nn.CrossEntropyLoss(weight=pos_wgt_tensor)
        if torch.cuda.is_available():
            self.model = DataParallelModel(self.model.cuda())
            self.criterion = DataParallelCriterion(self.criterion)
        self.optimizer = optim.Adam(self.model.parameters(), cfg.learning_rate)

    def run(self):
        """ do train

        Trains for cfg.epoch epochs, keeps the best model by validation
        F-score, and stops early after cfg.patience epochs without improvement.
        """
        max_f_score = -9e10
        max_epoch = -1
        for epoch in range(self.cfg.epoch):
            train_loss = self._train_epoch(epoch)
            metrics = self._evaluate(epoch)
            max_f_score_str = f' < {max_f_score:.2f}'
            if metrics['f_score'] > max_f_score:
                max_f_score_str = ' is max'
                max_f_score = metrics['f_score']
                max_epoch = epoch
                # Checkpoint only on improvement.
                torch.save(self.model.state_dict(), self.cfg.model_path)
            logging.info('EPOCH[%d]: train loss: %.6f, valid loss: %.6f, acc: %.2f,' \
                         ' F: %.2f%s', epoch, train_loss, metrics['loss'],
                         metrics['accuracy'], metrics['f_score'], max_f_score_str)
            if (epoch - max_epoch) >= self.cfg.patience:
                logging.info('early stopping...')
                break
        logging.info('epoch: %d, f-score: %.2f', max_epoch, max_f_score)

    def _train_epoch(self, epoch: int) -> float:
        """ train single epoch

        Args:
            epoch:  epoch number

        Returns:
            average loss
        """
        self.model.train()
        progress = tqdm(self.trn_itr, f'EPOCH[{epoch}]', mininterval=1,
                        ncols=100)
        losses = []
        for step, batch in enumerate(progress, start=1):
            outputs = self.model(batch.comment_text)
            # output of model wrapped with DataParallelModel is a list of outputs from each GPU
            # make input of DataParallelCriterion as a list of tuples
            if isinstance(self.model, DataParallelModel):
                loss = self.criterion([(output, ) for output in outputs],
                                      batch.target)
            else:
                loss = self.criterion(outputs, batch.target)
            losses.append(loss.item())
            if step % self.log_step == 0:
                avg_loss = sum(losses) / len(losses)
                progress.set_description(f'EPOCH[{epoch}] ({avg_loss:.6f})')
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        return sum(losses) / len(losses)

    def _evaluate(self, epoch: int) -> Dict[str, float]:
        """ evaluate on validation data

        Args:
            epoch:  epoch number

        Returns:
            metrics
        """
        self.model.eval()
        progress = tqdm(self.vld_itr, f' EVAL[{epoch}]', mininterval=1,
                        ncols=100)
        losses = []
        preds = []
        golds = []
        for step, batch in enumerate(progress, start=1):
            with torch.no_grad():
                outputs = self.model(batch.comment_text)
                if isinstance(self.model, DataParallelModel):
                    loss = self.criterion([(output, ) for output in outputs],
                                          batch.target)
                    # Collect argmax predictions from each per-GPU output.
                    for output in outputs:
                        preds.extend([(0 if o[0] < o[1] else 1)
                                      for o in output])
                else:
                    loss = self.criterion(outputs, batch.target)
                    preds.extend([(0 if output[0] < output[1] else 1)
                                  for output in outputs])
                losses.append(loss.item())
                golds.extend([gold.item() for gold in batch.target])
            if step % self.log_step == 0:
                avg_loss = sum(losses) / len(losses)
                progress.set_description(f' EVAL[{epoch}] ({avg_loss:.6f})')
        metrics = self._get_metrics(preds, golds)
        metrics['loss'] = sum(losses) / len(losses)
        return metrics

    @classmethod
    def _get_metrics(cls, preds: List[float],
                     golds: List[float]) -> Dict[str, float]:
        """ get metric values

        Computes accuracy, precision, recall and F1 (all as percentages) from
        binary predictions/golds thresholded at 0.5.

        Args:
            preds:  predictions
            golds:  gold standards

        Returns:
            metric
        """
        assert len(preds) == len(golds)
        true_pos = 0
        false_pos = 0
        false_neg = 0
        true_neg = 0
        for pred, gold in zip(preds, golds):
            if pred >= 0.5:
                if gold >= 0.5:
                    true_pos += 1
                else:
                    false_pos += 1
            else:
                if gold >= 0.5:
                    false_neg += 1
                else:
                    true_neg += 1
        accuracy = (true_pos + true_neg) / (true_pos + false_pos + false_neg +
                                            true_neg)
        # Guard against zero denominators when a class never occurs.
        precision = 0.0
        if (true_pos + false_pos) > 0:
            precision = true_pos / (true_pos + false_pos)
        recall = 0.0
        if (true_pos + false_neg) > 0:
            recall = true_pos / (true_pos + false_neg)
        f_score = 0.0
        if (precision + recall) > 0.0:
            f_score = 2.0 * precision * recall / (precision + recall)
        return {
            'accuracy': 100.0 * accuracy,
            'precision': 100.0 * precision,
            'recall': 100.0 * recall,
            'f_score': 100.0 * f_score,
        }
class ImageMTTrainer: def __init__(self, model, mask_prob: float = 0.3, clip: int = 1, optimizer=None, beam_width: int = 5, max_len_a: float = 1.1, max_len_b: int = 5, len_penalty_ratio: float = 0.8, nll_loss: bool = False, fp16: bool = False, mm_mode="mixed", rank: int = -1): self.model = model self.clip = clip self.optimizer = optimizer self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu") self.num_gpu = torch.cuda.device_count() self.mask_prob = mask_prob if nll_loss: self.criterion = nn.NLLLoss( ignore_index=model.text_processor.pad_token_id()) else: self.criterion = SmoothedNLLLoss( ignore_index=model.text_processor.pad_token_id()) self.num_gpu = torch.cuda.device_count() self.fp16 = False self.rank = rank if rank >= 0: self.device = torch.device('cuda', rank) torch.cuda.set_device(self.device) self.model = self.model.to(self.device) if fp16: self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O2") self.fp16 = True self.generator = BeamDecoder(self.model, beam_width=beam_width, max_len_a=max_len_a, max_len_b=max_len_b, len_penalty_ratio=len_penalty_ratio) if rank >= 0: self.model = DistributedDataParallel(self.model, device_ids=[self.rank], output_device=self.rank, find_unused_parameters=True) self.generator = DistributedDataParallel( self.generator, device_ids=[self.rank], output_device=self.rank, find_unused_parameters=True) elif self.num_gpu > 1: print("Let's use", self.num_gpu, "GPUs!") self.model = DataParallelModel(self.model) self.criterion = DataParallelCriterion(self.criterion) self.generator = DataParallelModel(self.generator) self.reference = None self.best_bleu = -1.0 self.mm_mode = mm_mode def train_epoch(self, img_data_iter: List[data_utils.DataLoader], step: int, saving_path: str = None, mass_data_iter: List[data_utils.DataLoader] = None, mt_dev_iter: List[data_utils.DataLoader] = None, mt_train_iter: List[data_utils.DataLoader] = None, max_step: int = 300000, accum=1, beam_width=1, fine_tune: 
bool = False, lang_directions: dict = False, lex_dict=None, save_opt: bool = False, **kwargs): "Standard Training and Logging Function" start = time.time() total_tokens, total_loss, tokens, cur_loss = 0, 0, 0, 0 cur_loss = 0 batch_zip, shortest = self.get_batch_zip(img_data_iter, mass_data_iter, mt_train_iter) model = (self.model.module if hasattr(self.model, "module") else self.model) self.optimizer.zero_grad() for i, batches in enumerate(batch_zip): for batch in batches: is_img_batch = isinstance(batch, list) and "captions" in batch[0] is_mass_batch = not is_img_batch and "dst_texts" not in batch is_contrastive = False try: if fine_tune and (is_img_batch or is_mass_batch): id2lid = lambda r: model.text_processor.languages[ model.text_processor.id2token(lang_directions[int( r)])] if is_mass_batch: src_inputs = batch["src_texts"].squeeze(0) src_pad_mask = src_inputs != model.text_processor.pad_token_id( ) pad_indices = batch["pad_idx"].squeeze(0) proposal = batch["proposal"].squeeze( 0) if lex_dict is not None else None target_langs = torch.LongTensor([ lang_directions[int(l)] for l in src_inputs[:, 0] ]) dst_langs = torch.LongTensor( [id2lid(l) for l in src_inputs[:, 0]]) else: src_inputs = [b["captions"] for b in batch] src_pad_mask = [b["caption_mask"] for b in batch] pad_indices = [b["pad_idx"] for b in batch] proposal = [ b["proposal"] if lex_dict is not None else None for b in batch ] target_langs = [ torch.LongTensor([ lang_directions[int(l)] for l in src[:, 0] ]) for src in src_inputs ] dst_langs = [ torch.LongTensor( [id2lid(l) for l in src[:, 0]]) for src in src_inputs ] if len(src_inputs) < self.num_gpu: continue if is_mass_batch: langs = batch["langs"].squeeze(0) else: langs = [b["langs"] for b in batch] model.eval() with torch.no_grad(): # We do not backpropagate the data generator following the MASS paper. 
images = None if is_img_batch: images = [b["images"] for b in batch] outputs = self.generator( src_inputs=src_inputs, src_sizes=pad_indices, first_tokens=target_langs, src_langs=langs, tgt_langs=dst_langs, pad_idx=model.text_processor.pad_token_id(), src_mask=src_pad_mask, unpad_output=False, beam_width=beam_width, images=images, proposals=proposal) if self.num_gpu > 1 and self.rank < 0: if is_mass_batch: new_outputs = [] for output in outputs: new_outputs += output outputs = new_outputs if is_mass_batch or self.num_gpu <= 1: translations = pad_sequence( outputs, batch_first=True, padding_value=model.text_processor. pad_token_id()) translation_proposals = None if lex_dict is not None: translation_proposals = list( map( lambda o: dataset. get_lex_suggestions( lex_dict, o, model.text_processor. pad_token_id()), outputs)) translation_proposals = pad_sequence( translation_proposals, batch_first=True, padding_value=model.text_processor. pad_token_id()) translation_pad_mask = ( translations != model.text_processor.pad_token_id()) else: translation_proposals = None if lex_dict is not None: translation_proposals = [ pad_sequence( list( map( lambda o: dataset. get_lex_suggestions( lex_dict, o, model.text_processor. pad_token_id()), output)), batch_first=True, padding_value=model.text_processor. pad_token_id()) for output in outputs ] translations = [ pad_sequence(output, batch_first=True, padding_value=model. text_processor.pad_token_id()) for output in outputs ] translation_pad_mask = [ t != model.text_processor.pad_token_id() for t in translations ] model.train() if is_mass_batch: langs = batch["langs"].squeeze(0) else: langs = torch.cat([b["langs"] for b in batch]) # Now use it for back-translation loss. 
predictions = model( src_inputs=translations, tgt_inputs=src_inputs, src_pads=translation_pad_mask, pad_idx=model.text_processor.pad_token_id(), src_langs=dst_langs, tgt_langs=langs, proposals=translation_proposals, log_softmax=True) if is_mass_batch: src_targets = src_inputs[:, 1:].contiguous().view(-1) src_mask_flat = src_pad_mask[:, 1:].contiguous().view( -1) else: src_targets = torch.cat( list(map(lambda s: s[:, 1:], src_inputs))) src_mask_flat = torch.cat( list(map(lambda s: s[:, 1:], src_pad_mask))) targets = src_targets[src_mask_flat] ntokens = targets.size(0) elif is_img_batch: src_inputs = [b["captions"] for b in batch] src_pad_mask = [b["caption_mask"] for b in batch] proposals = [b["proposal"] for b in batch ] if lex_dict is not None else None langs = [b["langs"] for b in batch] if (self.mm_mode == "mixed" and random.random() <= .5 ) or self.mm_mode == "masked": pad_indices = [b["pad_idx"] for b in batch] if len(batch) < self.num_gpu: continue # For image masking, we are allowed to mask more than mask_prob mask_prob = random.uniform(self.mask_prob, 1.0) masked_info = list( map( lambda pi, si: mass_mask( mask_prob, pi, si, model.text_processor ), pad_indices, src_inputs)) predictions = self.model( src_inputs=list( map(lambda m: m["src_text"], masked_info)), tgt_inputs=list( map(lambda m: m["to_recover"], masked_info)), tgt_positions=list( map(lambda m: m["positions"], masked_info)), src_pads=src_pad_mask, pad_idx=model.text_processor.pad_token_id(), src_langs=langs, batch=batch, proposals=proposals, log_softmax=True) targets = torch.cat( list(map(lambda m: m["targets"], masked_info))) ntokens = targets.size(0) else: neg_samples = [b["neg"] for b in batch] neg_mask = [b["neg_mask"] for b in batch] loss = self.model( src_inputs=src_inputs, src_pads=src_pad_mask, neg_samples=neg_samples, neg_mask=neg_mask, pad_idx=model.text_processor.pad_token_id(), src_langs=langs, batch=batch, proposals=proposals, log_softmax=True) is_contrastive = True elif not 
is_mass_batch: # MT data src_inputs = batch["src_texts"].squeeze(0) src_mask = batch["src_pad_mask"].squeeze(0) tgt_inputs = batch["dst_texts"].squeeze(0) tgt_mask = batch["dst_pad_mask"].squeeze(0) src_langs = batch["src_langs"].squeeze(0) dst_langs = batch["dst_langs"].squeeze(0) proposals = batch["proposal"].squeeze( 0) if lex_dict is not None else None if src_inputs.size(0) < self.num_gpu: continue predictions = self.model( src_inputs=src_inputs, tgt_inputs=tgt_inputs, src_pads=src_mask, tgt_mask=tgt_mask, src_langs=src_langs, tgt_langs=dst_langs, proposals=proposals, pad_idx=model.text_processor.pad_token_id(), log_softmax=True) targets = tgt_inputs[:, 1:].contiguous().view(-1) tgt_mask_flat = tgt_mask[:, 1:].contiguous().view(-1) targets = targets[tgt_mask_flat] ntokens = targets.size(0) else: # MASS data src_inputs = batch["src_texts"].squeeze(0) pad_indices = batch["pad_idx"].squeeze(0) proposals = batch["proposal"].squeeze( 0) if lex_dict is not None else None if src_inputs.size(0) < self.num_gpu: continue masked_info = mass_mask(self.mask_prob, pad_indices, src_inputs, model.text_processor) predictions = self.model( src_inputs=masked_info["src_text"], tgt_inputs=masked_info["to_recover"], tgt_positions=masked_info["positions"], pad_idx=model.text_processor.pad_token_id(), src_langs=batch["langs"].squeeze(0), proposals=proposals, log_softmax=True) targets = masked_info["targets"] ntokens = targets.size(0) if is_contrastive: # Nothing to predict! 
backward(loss, self.optimizer, self.fp16) loss = loss.data elif ntokens > 0: if self.num_gpu == 1: targets = targets.to(predictions.device) if self.rank >= 0: targets = targets.to(self.device) loss = self.criterion(predictions, targets).mean() backward(loss, self.optimizer, self.fp16) loss = float(loss.data) * ntokens tokens += ntokens total_tokens += ntokens total_loss += loss cur_loss += loss torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) step += 1 if step % accum == 0: self.optimizer.step() self.optimizer.zero_grad() if is_mass_batch and not fine_tune: mass_unmask(masked_info["src_text"], masked_info["src_mask"], masked_info["mask_idx"]) if not is_contrastive and is_img_batch and not fine_tune: map( lambda m: mass_unmask(m["src_text"], m["src_mask"], m["mask_idx"]), masked_info) if step % 50 == 0 and tokens > 0: elapsed = time.time() - start print( self.rank, "->", datetime.datetime.now(), "Epoch Step: %d Loss: %f Tokens per Sec: %f " % (step, cur_loss / tokens, tokens / elapsed)) if mt_dev_iter is not None and step % 5000 == 0 and self.rank <= 0: bleu = self.eval_bleu(mt_dev_iter, saving_path) print("BLEU:", bleu) if step % 10000 == 0: if self.rank <= 0: if self.rank < 0: model.cpu().save(saving_path + ".latest") elif self.rank == 0: model.save(saving_path + ".latest") if save_opt: with open( os.path.join( saving_path + ".latest", "optim"), "wb") as fp: pickle.dump(self.optimizer, fp) if self.rank < 0: model = model.to(self.device) start, tokens, cur_loss = time.time(), 0, 0 except RuntimeError as err: print(repr(err)) print("Error processing", is_img_batch) if (isinstance(model, ImageMassSeq2Seq)) and is_img_batch: for b in batch: print("->", len(b["images"]), b["captions"].size()) torch.cuda.empty_cache() if i == shortest - 1: break if step >= max_step: break try: if self.rank <= 0: print("Total loss in this epoch: %f" % (total_loss / total_tokens)) if self.rank < 0: model.cpu().save(saving_path + ".latest") model = model.to(self.device) 
                # --- tail of train_epoch (the method's def line is above this chunk) ---
                # rank 0 in distributed mode saves the epoch-end "latest" checkpoint;
                # the single-process branch (rank < 0) is handled just above.
                elif self.rank == 0:
                    model.save(saving_path + ".latest")
                if mt_dev_iter is not None:
                    bleu = self.eval_bleu(mt_dev_iter, saving_path)
                    print("BLEU:", bleu)
            except RuntimeError as err:
                # Best-effort epoch-end save/eval: failures are printed and
                # swallowed so the step count is still returned to the caller.
                print(repr(err))
        return step

    def get_batch_zip(self, img_data_iter, mass_data_iter, mt_train_iter):
        """Zip the non-None data iterators together for lock-step iteration.

        Returns a (zip_iterator, shortest) pair where ``shortest`` is the
        minimum loader length, used by the caller to stop before any loader
        is exhausted.
        """
        # if img_data_iter is not None and mt_train_iter is not None:
        #     img_data_iter *= 5
        # if mass_data_iter is not None and mt_train_iter is not None:
        #     mass_data_iter *= 5
        # NOTE(review): `x != None` works but `x is not None` is the idiomatic test.
        iters = list(
            chain(*filter(lambda x: x != None,
                          [img_data_iter, mass_data_iter, mt_train_iter])))
        shortest = min(len(l) for l in iters)
        return zip(*iters), shortest

    def eval_bleu(self, dev_data_iter, saving_path, save_opt: bool = False):
        """Decode the dev set, compute corpus BLEU against self.reference.

        Side effects: writes "bleu.output" under ``saving_path``; on a new
        best score also writes "bleu.best.output", saves the model (rank <= 0)
        and optionally pickles the optimizer. Leaves the model in train mode.
        Returns the BLEU score as a float.
        """
        mt_output = []
        src_text = []
        # Unwrap DataParallel/DDP if present so .save()/text_processor work.
        model = (self.model.module
                 if hasattr(self.model, "module") else self.model)
        model.eval()

        with torch.no_grad():
            for iter in dev_data_iter:
                for batch in iter:
                    src_inputs = batch["src_texts"].squeeze(0)
                    src_mask = batch["src_pad_mask"].squeeze(0)
                    tgt_inputs = batch["dst_texts"].squeeze(0)
                    src_langs = batch["src_langs"].squeeze(0)
                    dst_langs = batch["dst_langs"].squeeze(0)
                    src_pad_idx = batch["pad_idx"].squeeze(0)
                    proposal = batch["proposal"].squeeze(
                        0) if batch["proposal"] is not None else None

                    # Decode the (detokenized) source side for the report file.
                    src_ids = get_outputs_until_eos(
                        model.text_processor.sep_token_id(),
                        src_inputs,
                        remove_first_token=True)
                    src_text += list(
                        map(
                            lambda src: model.text_processor.tokenizer.decode(
                                src.numpy()), src_ids))

                    outputs = self.generator(
                        src_inputs=src_inputs,
                        src_sizes=src_pad_idx,
                        first_tokens=tgt_inputs[:, 0],
                        src_mask=src_mask,
                        src_langs=src_langs,
                        tgt_langs=dst_langs,
                        pad_idx=model.text_processor.pad_token_id(),
                        proposals=proposal)
                    # Multi-GPU (non-distributed) generation returns one list
                    # per device; flatten them back into a single list.
                    if self.num_gpu > 1 and self.rank < 0:
                        new_outputs = []
                        for output in outputs:
                            new_outputs += output
                        outputs = new_outputs

                    # x[1:] drops the forced first (language) token.
                    mt_output += list(
                        map(
                            lambda x: model.text_processor.tokenizer.decode(x[
                                1:].numpy()), outputs))

        model.train()
        bleu = sacrebleu.corpus_bleu(mt_output,
                                     [self.reference[:len(mt_output)]],
                                     lowercase=True,
                                     tokenize="intl")

        # NOTE(review): in the zip below `ref` iterates the hypotheses
        # (mt_output) and `o` the references — names look swapped, though the
        # written content is src/hyp/ref as intended. Verify before renaming.
        with open(os.path.join(saving_path, "bleu.output"), "w") as writer:
            writer.write("\n".join([
                src + "\n" + ref + "\n" + o + "\n\n***************\n"
                for src, ref, o in zip(src_text, mt_output,
                                       self.reference[:len(mt_output)])
            ]))

        if bleu.score > self.best_bleu:
            self.best_bleu = bleu.score
            print("Saving best BLEU", self.best_bleu)
            with open(os.path.join(saving_path, "bleu.best.output"),
                      "w") as writer:
                writer.write("\n".join([
                    src + "\n" + ref + "\n" + o + "\n\n***************\n"
                    for src, ref, o in zip(src_text, mt_output,
                                           self.reference[:len(mt_output)])
                ]))
            # Only rank <= 0 persists the model; model.save appears to require
            # CPU tensors in the single-process case, hence the cpu()/to() dance.
            if self.rank < 0:
                model.cpu().save(saving_path)
                model = model.to(self.device)
            elif self.rank == 0:
                model.save(saving_path)
            if save_opt:
                with open(os.path.join(saving_path, "optim"), "wb") as fp:
                    pickle.dump(self.optimizer, fp)

        return bleu.score

    @staticmethod
    def train(options):
        """Entry point: build model/optimizer/loaders and run the two training
        phases (pre-training up to options.step, then back-translation
        fine-tuning up to options.step + options.finetune_step)."""
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)
        if options.local_rank <= 0 and not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        # Downstream masking/padding logic assumes pad id 0.
        assert text_processor.pad_token_id() == 0

        # In non-distributed mode each GPU gets a shard of every batch.
        num_processors = max(torch.cuda.device_count(),
                             1) if options.local_rank < 0 else 1

        if options.pretrained_path is not None:
            mt_model = Seq2Seq.load(ImageMassSeq2Seq,
                                    options.pretrained_path,
                                    tok_dir=options.tokenizer_path)
        else:
            mt_model = ImageMassSeq2Seq(
                use_proposals=lex_dict is not None,
                tie_embed=options.tie_embed,
                text_processor=text_processor,
                resnet_depth=options.resnet_depth,
                lang_dec=options.lang_decoder,
                enc_layer=options.encoder_layer,
                dec_layer=options.decoder_layer,
                embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim)

        if options.lm_path is not None:
            lm = LM(text_processor=text_processor,
                    enc_layer=options.encoder_layer,
                    embed_dim=options.embed_dim,
                    intermediate_dim=options.intermediate_layer_dim)
            mt_model.init_from_lm(lm)

        print("Model initialization done!")

        # The collator is assumed to return one sub-batch per GPU
        # (a single-element list on CPU-only machines).
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        if options.continue_train:
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            # NOTE: "warump_steps" is the (misspelled) keyword build_optimizer
            # actually accepts — do not "fix" it here alone.
            optimizer = build_optimizer(mt_model,
                                        options.learning_rate,
                                        warump_steps=options.warmup)
        trainer = ImageMTTrainer(model=mt_model,
                                 mask_prob=options.mask_prob,
                                 optimizer=optimizer,
                                 clip=options.clip,
                                 beam_width=options.beam_width,
                                 max_len_a=options.max_len_a,
                                 max_len_b=options.max_len_b,
                                 len_penalty_ratio=options.len_penalty_ratio,
                                 fp16=options.fp16,
                                 mm_mode=options.mm_mode,
                                 rank=options.local_rank)

        pin_memory = torch.cuda.is_available()
        img_train_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionDataset,
            options.train_path,
            mt_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict)

        mass_train_data, mass_train_loader, finetune_loader, mt_dev_loader = None, None, None, None
        if options.mass_train_path is not None:
            mass_train_paths = options.mass_train_path.strip().split(",")
            if options.step > 0:
                # keep_examples so the same examples can seed the finetune data.
                mass_train_data, mass_train_loader = ImageMTTrainer.get_mass_loader(
                    mass_train_paths,
                    mt_model,
                    num_processors,
                    options,
                    pin_memory,
                    keep_examples=options.finetune_step > 0,
                    lex_dict=lex_dict)

            if options.finetune_step > 0:
                finetune_loader, finetune_data = ImageMTTrainer.get_mass_finetune_data(
                    mass_train_data,
                    mass_train_paths,
                    mt_model,
                    num_processors,
                    options,
                    pin_memory,
                    lex_dict=lex_dict)

        mt_train_loader = None
        if options.mt_train_path is not None:
            mt_train_loader = ImageMTTrainer.get_mt_train_data(
                mt_model,
                num_processors,
                options,
                pin_memory,
                lex_dict=lex_dict)

        mt_dev_loader = None
        if options.mt_dev_path is not None:
            mt_dev_loader = ImageMTTrainer.get_mt_dev_data(mt_model,
                                                           options,
                                                           pin_memory,
                                                           text_processor,
                                                           trainer,
                                                           lex_dict=lex_dict)

        # Phase 1: pre-training until the global step budget is spent.
        step, train_epoch = 0, 1
        while options.step > 0 and step < options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       mass_data_iter=mass_train_loader,
                                       mt_train_iter=mt_train_loader,
                                       max_step=options.step,
                                       lex_dict=lex_dict,
                                       mt_dev_iter=mt_dev_loader,
                                       saving_path=options.model_path,
                                       step=step,
                                       save_opt=options.save_opt,
                                       accum=options.accum)
            train_epoch += 1

        finetune_epoch = 0
        # Resetting the optimizer for the purpose of finetuning.
        trainer.optimizer.reset()

        lang_directions = ImageMTTrainer.get_lang_dirs(options.bt_langs,
                                                       text_processor)
        print(options.local_rank, "lang dirs", lang_directions)

        print(options.local_rank,
              "Reloading image train data with new batch size...")
        if options.finetune_step > 0 and img_train_loader is not None:
            # Halve image-batch capacity (denom=2) for back-translation decoding.
            img_train_loader = ImageMTTrainer.get_img_loader(
                collator,
                dataset.ImageCaptionDataset,
                options.train_path,
                mt_model,
                num_batches,
                options,
                pin_memory,
                denom=2,
                lex_dict=lex_dict)
        if options.ignore_mt_mass:
            mt_train_loader = None
        print(options.local_rank,
              "Reloading image train data with new batch size done!")

        # Phase 2: back-translation fine-tuning.
        while options.finetune_step > 0 and step <= options.finetune_step + options.step:
            print(options.local_rank, "finetune epoch", finetune_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       mass_data_iter=finetune_loader,
                                       mt_train_iter=mt_train_loader,
                                       max_step=options.finetune_step +
                                       options.step,
                                       mt_dev_iter=mt_dev_loader,
                                       saving_path=options.model_path,
                                       step=step,
                                       fine_tune=True,
                                       lang_directions=lang_directions,
                                       lex_dict=lex_dict,
                                       save_opt=options.save_opt,
                                       accum=options.accum,
                                       beam_width=options.bt_beam_width)
            finetune_epoch += 1

    @staticmethod
    def get_lang_dirs(bt_langs, text_processor: TextProcessor):
        """Map each language token id to the other one for back-translation.

        Returns None when fewer than two distinct languages are given;
        asserts at most two (the pairing loop only makes sense for exactly 2).
        """
        langs = ["<" + l + ">" for l in bt_langs.strip().split(",")]
        langs = set([text_processor.token_id(l) for l in langs])
        if len(langs) < 2:
            return None
        assert len(langs) <= 2
        lang_directions = {}
        for lang1 in langs:
            for lang2 in langs:
                if lang1 != lang2:
                    # Assuming that we only have two languages!
                    lang_directions[lang1] = lang2
        return lang_directions

    @staticmethod
    def get_mt_dev_data(mt_model,
                        options,
                        pin_memory,
                        text_processor,
                        trainer,
                        lex_dict=None):
        """Build dev DataLoaders and populate trainer.reference with decoded
        target-side reference strings (one list across all dev paths)."""
        mt_dev_loader = []
        dev_paths = options.mt_dev_path.split(",")
        trainer.reference = []
        for dev_path in dev_paths:
            mt_dev_data = dataset.MTDataset(
                batch_pickle_dir=dev_path,
                max_batch_capacity=options.total_capacity,
                keep_pad_idx=True,
                # Smaller batches: decoding keeps beam_width candidates alive.
                max_batch=int(options.batch / (options.beam_width * 2)),
                pad_idx=mt_model.text_processor.pad_token_id(),
                lex_dict=lex_dict)
            dl = data_utils.DataLoader(mt_dev_data,
                                       batch_size=1,
                                       shuffle=False,
                                       pin_memory=pin_memory)
            mt_dev_loader.append(dl)

            print(options.local_rank, "creating reference")
            generator = (trainer.generator.module if hasattr(
                trainer.generator, "module") else trainer.generator)
            for batch in dl:
                tgt_inputs = batch["dst_texts"].squeeze()
                refs = get_outputs_until_eos(text_processor.sep_token_id(),
                                             tgt_inputs,
                                             remove_first_token=True)
                ref = [
                    generator.seq2seq_model.text_processor.tokenizer.decode(
                        ref.numpy()) for ref in refs
                ]
                trainer.reference += ref
        return mt_dev_loader

    @staticmethod
    def get_mt_train_data(mt_model,
                          num_processors,
                          options,
                          pin_memory,
                          lex_dict=None):
        """Build one MT training DataLoader per comma-separated train path."""
        mt_train_loader = []
        train_paths = options.mt_train_path.split(",")
        for train_path in train_paths:
            mt_train_data = dataset.MTDataset(
                batch_pickle_dir=train_path,
                max_batch_capacity=int(num_processors *
                                       options.total_capacity / 2),
                max_batch=int(num_processors * options.batch / 2),
                pad_idx=mt_model.text_processor.pad_token_id(),
                lex_dict=lex_dict,
                keep_pad_idx=False)
            mtl = data_utils.DataLoader(
                mt_train_data,
                sampler=None if options.local_rank < 0 else DistributedSampler(
                    mt_train_data, rank=options.local_rank),
                batch_size=1,
                shuffle=(options.local_rank < 0),
                pin_memory=pin_memory)
            mt_train_loader.append(mtl)
        return mt_train_loader

    @staticmethod
    def get_mass_finetune_data(mass_train_data,
                               mass_train_paths,
                               mt_model,
                               num_processors,
                               options,
                               pin_memory,
                               lex_dict=None):
        """Build MASS fine-tune datasets/loaders, reusing the example lists
        of the pre-training datasets (which are emptied afterwards to free
        memory). Returns (loaders, datasets)."""
        finetune_data, finetune_loader = [], []
        for i, mass_train_path in enumerate(mass_train_paths):
            fd = dataset.MassDataset(
                batch_pickle_dir=mass_train_path,
                # Capacity shrinks with the back-translation beam width.
                max_batch_capacity=int(num_processors *
                                       options.total_capacity /
                                       max(2, options.bt_beam_width)),
                max_batch=int(num_processors * options.batch /
                              max(2, options.bt_beam_width)),
                pad_idx=mt_model.text_processor.pad_token_id(),
                max_seq_len=options.max_seq_len,
                keep_examples=False,
                example_list=None if mass_train_data is None else
                mass_train_data[i].examples_list,
                lex_dict=lex_dict)
            finetune_data.append(fd)
            fl = data_utils.DataLoader(
                fd,
                sampler=None if options.local_rank < 0 else DistributedSampler(
                    fd, rank=options.local_rank),
                batch_size=1,
                shuffle=(options.local_rank < 0),
                pin_memory=pin_memory)
            finetune_loader.append(fl)
            # Drop the (now copied) example list to release memory.
            if mass_train_data is not None:
                mass_train_data[i].examples_list = []
        return finetune_loader, finetune_data

    @staticmethod
    def get_mass_loader(mass_train_paths,
                        mt_model,
                        num_processors,
                        options,
                        pin_memory,
                        keep_examples,
                        lex_dict=None):
        """Build MASS pre-training datasets/loaders, one per path.

        ``keep_examples`` retains raw examples on the dataset so
        get_mass_finetune_data can reuse them later.
        """
        mass_train_data, mass_train_loader = [], []
        for i, mass_train_path in enumerate(mass_train_paths):
            td = dataset.MassDataset(
                batch_pickle_dir=mass_train_path,
                max_batch_capacity=num_processors * options.total_capacity,
                max_batch=num_processors * options.batch,
                pad_idx=mt_model.text_processor.pad_token_id(),
                max_seq_len=options.max_seq_len,
                keep_examples=keep_examples,
                lex_dict=lex_dict)
            mass_train_data.append(td)
            dl = data_utils.DataLoader(
                td,
                sampler=None if options.local_rank < 0 else DistributedSampler(
                    td, rank=options.local_rank),
                batch_size=1,
                shuffle=(options.local_rank < 0),
                pin_memory=pin_memory)
            mass_train_loader.append(dl)
        return mass_train_data, mass_train_loader

    @staticmethod
    def get_img_loader(collator,
                       dataset_class,
                       paths,
                       mt_model,
                       num_batches,
                       options,
                       pin_memory,
                       denom=1,
                       lex_dict=None,
                       shuffle=True):
        """Build image-caption DataLoaders, one per comma-separated path.

        ``denom`` divides the per-batch image capacity (used to shrink batches
        during fine-tuning). Returns None when ``paths`` is None.

        NOTE(review): in distributed mode (local_rank >= 0) a DistributedSampler
        is passed together with shuffle=True; torch's DataLoader rejects that
        combination — confirm whether shuffle should be forced False there.
        """
        if paths is not None:
            img_loader = []
            for pth in paths.strip().split(","):
                data = dataset_class(
                    root_img_dir=options.image_dir,
                    data_bin_file=pth,
                    max_capacity=int(options.img_capacity / denom),
                    text_processor=mt_model.text_processor,
                    max_img_per_batch=options.max_image / denom,
                    lex_dict=lex_dict)
                print(options.local_rank, pth, "Length of training data",
                      len(data))
                tl = data_utils.DataLoader(
                    data,
                    sampler=None if options.local_rank < 0 else
                    DistributedSampler(data, rank=options.local_rank),
                    batch_size=num_batches,
                    shuffle=shuffle,
                    pin_memory=pin_memory,
                    collate_fn=collator)
                img_loader.append(tl)
            return img_loader
        return None
def main_tr(args, crossVal):
    """Train MiniSeg for one cross-validation fold.

    Builds the model, loss, multi-scale data loaders and optimizer, optionally
    restores pretrained/resumed weights, then trains for args.max_epochs,
    checkpointing and logging every epoch.

    Returns (maxEpoch, maxmIOU): the 1-based epoch with the best validation
    mIOU and that mIOU value.
    """
    dataLoad = ld.LoadData(args.data_dir, args.classes)
    data = dataLoad.processData(crossVal, args.data_name)

    # load the model (aux=True enables the auxiliary prediction heads)
    model = net.MiniSeg(args.classes, aux=True)

    # Build <savedir>_mod<max_epochs>/<data_name>/<model_name> level by level.
    # NOTE(review): plain os.mkdir (not makedirs) — parent dirs must exist and
    # this is race-prone under concurrent runs; exist_ok=True makedirs would be
    # safer if behavior may change.
    if not osp.isdir(osp.join(args.savedir + '_mod' + str(args.max_epochs))):
        os.mkdir(args.savedir + '_mod' + str(args.max_epochs))
    if not osp.isdir(
            osp.join(args.savedir + '_mod' + str(args.max_epochs),
                     args.data_name)):
        os.mkdir(
            osp.join(args.savedir + '_mod' + str(args.max_epochs),
                     args.data_name))
    saveDir = args.savedir + '_mod' + str(
        args.max_epochs) + '/' + args.data_name + '/' + args.model_name
    # create the directory if not exist
    if not osp.exists(saveDir):
        os.mkdir(saveDir)

    if args.gpu and torch.cuda.device_count() > 1:
        #model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)
    if args.gpu:
        model = model.cuda()

    # (sic) "total_paramters" — kept as-is; purely a local name.
    total_paramters = sum([np.prod(p.size()) for p in model.parameters()])
    print('Total network parameters: ' + str(total_paramters))

    # define optimization criteria
    weight = torch.from_numpy(
        data['classWeights'])  # convert the numpy array to torch
    if args.gpu:
        weight = weight.cuda()

    criteria = CrossEntropyLoss2d(weight, args.ignore_label)  #weight
    if args.gpu and torch.cuda.device_count() > 1:
        criteria = DataParallelCriterion(criteria)
    if args.gpu:
        criteria = criteria.cuda()

    # compose the data with transforms: one main pipeline plus three rescaled
    # variants for multi-scale augmentation.
    trainDataset_main = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(args.width, args.height),
        myTransforms.RandomCropResize(int(32. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])
    trainDataset_scale1 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 1.5), int(args.height * 1.5)),
        myTransforms.RandomCropResize(int(100. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])
    trainDataset_scale2 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 1.25), int(args.height * 1.25)),
        myTransforms.RandomCropResize(int(100. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])
    trainDataset_scale3 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 0.75), int(args.height * 0.75)),
        myTransforms.RandomCropResize(int(32. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])
    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(args.width, args.height),
        myTransforms.ToTensor()
    ])

    # since we training from scratch, we create data loaders at different scales
    # so that we can generate more augmented data and prevent the network from overfitting
    trainLoader = torch.utils.data.DataLoader(myDataLoader.Dataset(
        data['trainIm'], data['trainAnnot'], transform=trainDataset_main),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              drop_last=True)
    trainLoader_scale1 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale1),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)
    trainLoader_scale2 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale2),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)
    trainLoader_scale3 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale3),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)
    valLoader = torch.utils.data.DataLoader(myDataLoader.Dataset(
        data['valIm'], data['valAnnot'], transform=valDataset),
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=True)

    # Total batches across all four training loaders, used by train() to
    # compute global iteration progress.
    max_batches = len(trainLoader) + len(trainLoader_scale1) + len(
        trainLoader_scale2) + len(trainLoader_scale3)

    if args.gpu:
        cudnn.benchmark = True

    start_epoch = 0

    # Warm-start from pretrained weights, dropping prediction-head keys
    # ('pred') so class counts may differ; strict=False tolerates the rest.
    if args.pretrained is not None:
        state_dict = torch.load(args.pretrained)
        new_keys = []
        new_values = []
        for idx, key in enumerate(state_dict.keys()):
            if 'pred' not in key:
                new_keys.append(key)
                new_values.append(list(state_dict.values())[idx])
        new_dict = OrderedDict(list(zip(new_keys, new_values)))
        model.load_state_dict(new_dict, strict=False)
        print('pretrained model loaded')

    # Resume a full checkpoint: restores epoch, lr and model weights.
    # NOTE(review): the optimizer state saved in the checkpoint is not
    # restored here — confirm whether that is intentional.
    if args.resume is not None:
        if osp.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            args.lr = checkpoint['lr']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Append to an existing log; write the header only for a fresh file.
    log_file = osp.join(saveDir, 'trainValLog_' + args.model_name + '.txt')
    if osp.isfile(log_file):
        logger = open(log_file, 'a')
    else:
        logger = open(log_file, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t%s\t\t%s\t%s\t%s\t%s\tlr" %
                     ('CrossVal', 'Epoch', 'Loss(Tr)', 'Loss(val)',
                      'mIOU (tr)', 'mIOU (val)'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=1e-4)

    maxmIOU = 0
    maxEpoch = 0
    print(args.model_name + '-CrossVal: ' + str(crossVal + 1))
    for epoch in range(start_epoch, args.max_epochs):
        # train for one epoch: three auxiliary-scale passes (metrics ignored),
        # then the main-scale pass whose metrics are kept.
        cur_iter = 0
        train(args, trainLoader_scale1, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale1)
        train(args, trainLoader_scale2, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale2)
        train(args, trainLoader_scale3, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale3)
        lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr, lr = \
            train(args, trainLoader, model, criteria, optimizer, epoch, max_batches, cur_iter)

        # evaluate on validation set
        lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = \
            val(args, valLoader, model, criteria)

        # Full checkpoint (overwritten each epoch) for resume.
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lossTr': lossTr,
                'lossVal': lossVal,
                'iouTr': mIOU_tr,
                'iouVal': mIOU_val,
                'lr': lr
            },
            osp.join(
                saveDir, 'checkpoint_' + args.model_name + '_crossVal' +
                str(crossVal + 1) + '.pth.tar'))

        # save the model also (per-epoch weights-only snapshot)
        model_file_name = osp.join(
            saveDir, 'model_' + args.model_name + '_crossVal' +
            str(crossVal + 1) + '_' + str(epoch + 1) + '.pth')
        torch.save(model.state_dict(), model_file_name)

        logger.write(
            "\n%d\t\t%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
            (crossVal + 1, epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val, lr))
        logger.flush()
        print("\nEpoch No. %d:\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\n" \
              % (epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val))

        # Track the best validation mIOU and the epoch it occurred.
        if mIOU_val >= maxmIOU:
            maxmIOU = mIOU_val
            maxEpoch = epoch + 1
        torch.cuda.empty_cache()
    logger.flush()
    logger.close()
    return maxEpoch, maxmIOU