def __call__(
    self,
    net: nn.Module,
    train_iter: DataLoader,
    validation_iter: Optional[DataLoader] = None,
) -> None:
    """Train ``net`` over ``train_iter`` for ``self.epochs`` epochs.

    Uses Adam plus a one-cycle LR schedule sized to exactly
    ``self.epochs * self.num_batches_per_epoch`` optimizer steps, and logs
    the per-batch loss to Weights & Biases.

    NOTE(review): ``validation_iter`` is accepted but never used in the
    visible code — confirm whether a validation pass was intended here.
    """
    # Ask wandb to track gradients/parameters, sampled once per epoch.
    wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch)
    optimizer = Adam(net.parameters(), lr=self.learning_rate,
                     weight_decay=self.weight_decay)
    lr_scheduler = OneCycleLR(
        optimizer,
        max_lr=self.maximum_learning_rate,
        steps_per_epoch=self.num_batches_per_epoch,
        epochs=self.epochs,
    )
    for epoch_no in range(self.epochs):
        # mark epoch start time
        tic = time.time()
        avg_epoch_loss = 0.0
        with tqdm(train_iter) as it:
            for batch_no, data_entry in enumerate(it, start=1):
                optimizer.zero_grad()
                # Move every tensor of the batch to the training device; the
                # network is called positionally with the dict's values.
                inputs = [v.to(self.device) for v in data_entry.values()]
                output = net(*inputs)
                # Networks may return (loss, ...) tuples; the loss comes first.
                if isinstance(output, (list, tuple)):
                    loss = output[0]
                else:
                    loss = output
                avg_epoch_loss += loss.item()
                it.set_postfix(
                    ordered_dict={
                        "avg_epoch_loss": avg_epoch_loss / batch_no,
                        "epoch": epoch_no,
                    },
                    refresh=False,
                )
                wandb.log({"loss": loss.item()})
                loss.backward()
                # Optional gradient clipping before the optimizer update.
                if self.clip_gradient is not None:
                    nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)
                optimizer.step()
                # OneCycleLR is stepped once per batch, not per epoch.
                lr_scheduler.step()
                # Cap every epoch at a fixed number of batches.
                if self.num_batches_per_epoch == batch_no:
                    break
        # mark epoch end time and log time cost of current epoch
        toc = time.time()
def fit(self, learning_rate: Tuple[float, float]):
    """Train ``self.model`` with a one-cycle LR policy.

    Args:
        learning_rate: ``max_lr`` value(s) forwarded to ``OneCycleLR``.

    Returns:
        ``pd.DataFrame`` with one row per epoch and columns
        ``'Train error'`` / ``'Validation error'``.

    Fixes vs. the original:
      * per-epoch averages divided by the last ``enumerate`` index
        (i.e. ``n_batches - 1``) instead of the batch count — off by one;
      * deprecated ``scheduler.get_lr()`` replaced by ``get_last_lr()``;
      * validation now runs under ``torch.no_grad()`` so no autograd
        graph is built for the held-out pass.
    """
    import torch  # local import: only needed for no_grad(), keeps fix self-contained

    # Capture learning errors
    self.train_val_error = {"train": [], "validation": [], "lr": []}
    self._init_model(
        model=self.model_, optimizer=self.optimizer_, criterion=self.criterion_
    )
    # Setup one cycle policy (must be stepped once per batch, see
    # https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.OneCycleLR)
    scheduler = OneCycleLR(
        optimizer=self.optimizer,
        max_lr=learning_rate,
        steps_per_epoch=len(self.train_loader),
        epochs=self.n_epochs,
        anneal_strategy="cos",
    )
    n_train_batches = len(self.train_loader)
    n_valid_batches = len(self.valid_loader)
    for epoch in range(self.n_epochs):
        # ---- Training pass ----
        self.model.train()
        train_loss = 0.0
        for samples in self.train_loader:
            loss = self._forward_pass(samples=samples)
            train_loss += loss.item()
            # Zero gradients, perform a backward pass, and update the weights.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Record the LR that was used for this step, then advance the cycle.
            self.train_val_error["lr"].append(scheduler.get_last_lr()[0])
            scheduler.step()
        # Average over the true number of batches (original divided by the
        # last 0-based index, over-estimating the mean loss).
        self.train_val_error["train"].append(train_loss / n_train_batches)

        # ---- Validation pass (no gradients needed) ----
        self.model.eval()
        validation_loss = 0.0
        with torch.no_grad():
            for samples in self.valid_loader:
                loss = self._forward_pass(samples=samples)
                validation_loss += loss.item()
        self.train_val_error["validation"].append(validation_loss / n_valid_batches)

    return pd.DataFrame(data={
        'Train error': self.train_val_error['train'],
        'Validation error': self.train_val_error['validation']
    })
def train(model, device, train_loader, optimizer, epoch, scheduler=None, lambda_l1=0.01):
    """Train ``model`` for one epoch with NLL loss plus an L1 penalty.

    Args:
        model: network to train (must output log-probabilities for nll_loss).
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer to step — now actually used (see fix below).
        epoch: current epoch number (kept for interface compatibility).
        scheduler: optional LR scheduler stepped once per batch.  Pass the
            same scheduler across epochs to keep one continuous cycle.
        lambda_l1: weight of the L1 regularization term (was hard-coded).

    Returns:
        ``(train_losses, train_acc)`` — the module-level history lists.

    BUG FIX: the original built a brand-new SGD optimizer inside this
    function, silently discarding the one passed by the caller, and rebuilt
    the OneCycleLR schedule on every call so the cycle restarted each epoch.
    The caller's optimizer is now honoured; a schedule is created only as a
    fallback when none is supplied.
    """
    model.train()
    pbar = tqdm(train_loader)
    correct = 0
    processed = 0
    if scheduler is None:
        # Fallback: per-call schedule (matches the old behaviour when no
        # scheduler is shared across epochs).
        scheduler = OneCycleLR(optimizer, max_lr=0.020, epochs=20,
                               steps_per_epoch=len(train_loader))
    for batch_idx, (data, target) in enumerate(pbar):
        # get samples
        data, target = data.to(device), target.to(device)
        # PyTorch accumulates gradients across backward passes, so zero them
        # at the start of every batch.
        optimizer.zero_grad()
        # Predict
        y_pred = model(data)
        # Calculate loss; record it detached so the history list does not
        # keep every batch's autograd graph alive (memory fix).
        loss = F.nll_loss(y_pred, target)
        train_losses.append(loss.detach())
        # L1 penalty over all parameters.
        l1 = 0
        for p in model.parameters():
            l1 += p.abs().sum()
        loss = loss + lambda_l1 * l1
        # Backpropagation
        loss.backward()
        optimizer.step()
        # OneCycleLR is stepped per batch.
        scheduler.step()
        # Update pbar-tqdm
        pred = y_pred.argmax(dim=1, keepdim=True)  # index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)
        pbar.set_description(
            desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}'
        )
        train_acc.append(100 * correct / processed)
    return train_losses, train_acc
def init_train(self, con_weight: float = 1.0):
    """Pre-train the generator ``self.G`` on content reconstruction only.

    Args:
        con_weight: weight applied to the content loss.

    NOTE(review): ``g['lr'] = self.init_lr`` is assigned *after* the
    OneCycleLR scheduler is constructed; OneCycleLR overwrites the param
    group lr on construction and on every ``step()``, so the manual
    assignment is clobbered by the first step — confirm which initial LR
    was actually intended.
    """
    test_img = self.get_test_image()
    meter = AverageMeter("Loss")
    self.writer.flush()
    # One-cycle schedule sized to the full init-training run (per-batch steps).
    lr_scheduler = OneCycleLR(self.optimizer_G, max_lr=0.9999,
                              steps_per_epoch=len(self.dataloader),
                              epochs=self.init_train_epoch)
    for g in self.optimizer_G.param_groups:
        g['lr'] = self.init_lr
    for epoch in tqdm(range(self.init_train_epoch)):
        meter.reset()
        for i, (style, smooth, train) in enumerate(self.dataloader, 0):
            # train = transform(test_img).unsqueeze(0)
            self.G.zero_grad(set_to_none=self.grad_set_to_none)
            train = train.to(self.device)
            generator_output = self.G(train)
            # content_loss = loss.reconstruction_loss(generator_output, train) * con_weight
            content_loss = self.loss.content_loss(generator_output, train) * con_weight
            # content_loss = F.mse_loss(train, generator_output) * con_weight
            content_loss.backward()
            self.optimizer_G.step()
            lr_scheduler.step()
            # detach() so the meter does not retain the autograd graph.
            meter.update(content_loss.detach())
        # Per-epoch bookkeeping: scalar log, weight snapshot, sample image.
        self.writer.add_scalar(f"Loss : {self.init_time}", meter.sum.item(), epoch)
        self.write_weights(epoch + 1, write_D=False)
        self.eval_image(epoch, f'{self.init_time} reconstructed img', test_img)
    # Restore the regular generator learning rate for the main training phase.
    for g in self.optimizer_G.param_groups:
        g['lr'] = self.G_lr
class Regularizations:
    """Bundles an optimizer, a one-cycle LR scheduler, and loss/dropout helpers.

    ``optim_type`` / ``loss_type`` are resolved by name against
    ``torch.optim`` and ``torch.nn.functional`` respectively.
    """

    @staticmethod
    def dropout(dropout_value):
        """Return an ``nn.Dropout`` with probability ``dropout_value``.

        BUG FIX: the original coerced with ``int(...)``, which truncates any
        fractional probability (e.g. 0.5 -> 0), silently disabling dropout.
        ``float`` preserves the requested rate and still accepts strings.
        """
        return nn.Dropout(float(dropout_value))

    def __init__(self, optim_type, model, lr, momentum, max_lr, len_loader, weight_decay=0):
        # e.g. optim_type='SGD' resolves to torch.optim.SGD.
        self.optimizer = getattr(optim, optim_type)(getattr(model, 'parameters')(),
                                                    lr=lr,
                                                    momentum=momentum,
                                                    weight_decay=weight_decay)
        # Fixed 50-epoch one-cycle policy; warmup covers the first 10 epochs.
        self.scheduler = OneCycleLR(self.optimizer,
                                    max_lr=max_lr,
                                    steps_per_epoch=len_loader,
                                    epochs=50,
                                    div_factor=10,
                                    final_div_factor=1,
                                    pct_start=10 / 50)

    def loss_function(self, loss_type, preds, targets):
        """Apply the functional loss named ``loss_type`` (e.g. 'nll_loss')."""
        return getattr(F, loss_type)(preds, targets)

    def optimizer_step(self, loss=False, step=0):
        """step == 0: zero grads, backprop ``loss`` and step the optimizer;
        any other ``step`` value advances the LR scheduler instead."""
        if step == 0:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        elif step != 0:
            self.scheduler.step()
def _train(start_iteration, model, optimizer, device, train_dataloader, test_dataloader, args):
    """Main optimization loop: train, periodically evaluate, checkpoint and log.

    Resumable: ``start_iteration`` offsets both the loop and the OneCycleLR
    schedule so a restored run continues the same LR curve.
    """
    # Rolling windows so reported losses are means over the last log_freq steps.
    train_loss = deque(maxlen=args.log_freq)
    test_loss = deque(maxlen=args.log_freq)
    model = model.to(device)
    start_time = time.perf_counter()
    # NOTE(review): plain iterators that are never re-created — a finite
    # dataloader will raise StopIteration before num_training_steps is
    # reached; confirm these loaders are effectively infinite.
    test_iter = iter(test_dataloader)
    train_iter = iter(train_dataloader)
    loss_func = partial(_loss_func, model=model, device=device)
    # last_epoch = start_iteration - 2 resumes the schedule so the next
    # step() corresponds to step start_iteration - 1.
    oclr = OneCycleLR(optimizer, args.learning_rate, pct_start=0.01,
                      total_steps=1_000_000, cycle_momentum=False,
                      last_epoch=start_iteration - 2)
    for iteration in range(start_iteration, 1 + args.num_training_steps):
        loss = loss_func(train_iter)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # One-cycle schedule advances once per optimizer step.
        oclr.step()
        train_loss.append(loss.detach())
        # Checkpoint every 10 * log_freq iterations.
        if iteration % (10 * args.log_freq) == 0:
            ckpt = f'checkpoint_{iteration:07d}.pt'
            print('Saving checkpoint', ckpt)
            torch.save(
                {
                    'iteration': iteration,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'args': args
                }, ckpt)
        # Cheap held-out evaluation every 20 iterations.
        if iteration % 20 == 0:
            with torch.no_grad():
                model.eval()
                test_loss.append(loss_func(test_iter).detach())
                model.train()
        # Periodic console / W&B logging with generated + real samples.
        if iteration % args.log_freq == 0:
            avg_train_loss = sum(train_loss).item() / len(train_loss)
            avg_test_loss = sum(test_loss).item() / len(test_loss)
            end_time = time.perf_counter()
            duration, start_time = end_time - start_time, end_time
            lr = oclr.get_last_lr()[0]
            with torch.no_grad():
                model.eval()
                # Sample a random category and draw from the model.
                cat = random.randrange(0, len(dataset.categories))
                sample = generate(model, device, cat)
                model.train()
            # First element of the next train/test batch, rendered to disk.
            train_sample = next(train_iter)[0, :]
            test_sample = next(test_iter)[0, :]
            plot_encoded_figure(train_sample[:, 0].tolist(), train_sample[0, 2], 'train_sample.png')
            plot_encoded_figure(test_sample[:, 0].tolist(), test_sample[0, 2], 'test_sample.png')
            plot_encoded_figure(sample, cat, 'random_sample.png')
            print(
                f'Iteration {iteration:07d} Train loss {avg_train_loss:.3f} Test loss {avg_test_loss:.3f} LR {lr:.3e} Duration {duration:.3f}'
            )
            if args.use_wandb:
                wandb.log({
                    'iteration': iteration,
                    'train loss': avg_train_loss,
                    'test loss': avg_test_loss,
                    'duration': duration,
                    'learning rate': lr,
                    'train sample': wandb.Image('train_sample.png'),
                    'test sample': wandb.Image('test_sample.png'),
                    'random sample': wandb.Image('random_sample.png'),
                })
def train(args, checkpoint, mid_checkpoint_location, final_checkpoint_location,
          best_checkpoint_location, actfun, curr_seed, outfile_path, filename,
          fieldnames, curr_sample_size, device, num_params,
          curr_k=2, curr_p=1, curr_g=1, perm_method='shuffle'):
    """
    Runs training session for a given randomized model
    :param args: arguments for this job
    :param checkpoint: current checkpoint
    :param checkpoint_location: output directory for checkpoints
    :param actfun: activation function currently being used
    :param curr_seed: seed being used by current job
    :param outfile_path: path to save outputs from training session
    :param fieldnames: column names for output file
    :param device: reference to CUDA device for GPU support
    :param num_params: number of parameters in the network
    :param curr_k: k value for this iteration
    :param curr_p: p value for this iteration
    :param curr_g: g value for this iteration
    :param perm_method: permutation strategy for our network
    :return:

    BUG FIX vs. original: the final-epoch train-set evaluation computed
    ``epoch_train_loss = total_val_loss / n`` (leftover from the validation
    pass) instead of ``total_train_loss / n``, corrupting both
    'epoch_train_loss' and 'gen_gap' in the output CSV.
    """
    resnet_ver = args.resnet_ver
    resnet_width = args.resnet_width
    num_epochs = args.num_epochs

    # 1-D activation functions do not use the k-grouping trick.
    actfuns_1d = ['relu', 'abs', 'swish', 'leaky_relu', 'tanh']
    if actfun in actfuns_1d:
        curr_k = 1

    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

    if args.one_shot:
        # One-shot mode: run an LR finder on a throwaway model/dataset first.
        util.seed_all(curr_seed)
        model_temp, _ = load_model(args.model, args.dataset, actfun, curr_k, curr_p, curr_g,
                                   num_params=num_params, perm_method=perm_method,
                                   device=device, resnet_ver=resnet_ver,
                                   resnet_width=resnet_width, verbose=args.verbose)
        util.seed_all(curr_seed)
        dataset_temp = util.load_dataset(args, args.model, args.dataset, seed=curr_seed,
                                         validation=True, batch_size=args.batch_size,
                                         train_sample_size=curr_sample_size, kwargs=kwargs)
        curr_hparams = hparams.get_hparams(args.model, args.dataset, actfun, curr_seed,
                                           num_epochs, args.search, args.hp_idx,
                                           args.one_shot)
        optimizer = optim.Adam(model_temp.parameters(),
                               betas=(curr_hparams['beta1'], curr_hparams['beta2']),
                               eps=curr_hparams['eps'],
                               weight_decay=curr_hparams['wd'])
        start_time = time.time()
        oneshot_fieldnames = fieldnames if args.search else None
        oneshot_outfile_path = outfile_path if args.search else None
        lr = util.run_lr_finder(args, model_temp, dataset_temp[0], optimizer,
                                nn.CrossEntropyLoss(), val_loader=dataset_temp[3],
                                show=False, device=device,
                                fieldnames=oneshot_fieldnames,
                                outfile_path=oneshot_outfile_path,
                                hparams=curr_hparams)
        curr_hparams = {}
        print("Time to find LR: {}\n LR found: {:3e}".format(time.time() - start_time, lr))
    else:
        curr_hparams = hparams.get_hparams(args.model, args.dataset, actfun, curr_seed,
                                           num_epochs, args.search, args.hp_idx)
        lr = curr_hparams['max_lr']

    criterion = nn.CrossEntropyLoss()

    # Fresh model for the real run, deterministically initialized.
    model, model_params = load_model(args.model, args.dataset, actfun, curr_k, curr_p, curr_g,
                                     num_params=num_params, perm_method=perm_method,
                                     device=device, resnet_ver=resnet_ver,
                                     resnet_width=resnet_width, verbose=args.verbose)
    util.seed_all(curr_seed)
    model.apply(util.weights_init)
    util.seed_all(curr_seed)
    dataset = util.load_dataset(args, args.model, args.dataset, seed=curr_seed,
                                validation=args.validation, batch_size=args.batch_size,
                                train_sample_size=curr_sample_size, kwargs=kwargs)
    loaders = {
        'aug_train': dataset[0],
        'train': dataset[1],
        'aug_eval': dataset[2],
        'eval': dataset[3],
    }
    sample_size = dataset[4]
    batch_size = dataset[5]

    if args.one_shot:
        optimizer = optim.Adam(model_params)
        scheduler = OneCycleLR(optimizer,
                               max_lr=lr,
                               epochs=num_epochs,
                               steps_per_epoch=int(math.floor(sample_size / batch_size)),
                               cycle_momentum=False)
    else:
        optimizer = optim.Adam(model_params,
                               betas=(curr_hparams['beta1'], curr_hparams['beta2']),
                               eps=curr_hparams['eps'],
                               weight_decay=curr_hparams['wd'])
        scheduler = OneCycleLR(optimizer,
                               max_lr=curr_hparams['max_lr'],
                               epochs=num_epochs,
                               steps_per_epoch=int(math.floor(sample_size / batch_size)),
                               pct_start=curr_hparams['cycle_peak'],
                               cycle_momentum=False)

    epoch = 1
    if checkpoint is not None:
        # Resume model/optimizer/scheduler state from a previous run.
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        epoch = checkpoint['epoch']
        model.to(device)
        print("*** LOADED CHECKPOINT ***"
              "\n{}"
              "\nSeed: {}"
              "\nEpoch: {}"
              "\nActfun: {}"
              "\nNum Params: {}"
              "\nSample Size: {}"
              "\np: {}"
              "\nk: {}"
              "\ng: {}"
              "\nperm_method: {}".format(mid_checkpoint_location,
                                         checkpoint['curr_seed'],
                                         checkpoint['epoch'],
                                         checkpoint['actfun'],
                                         checkpoint['num_params'],
                                         checkpoint['sample_size'],
                                         checkpoint['p'],
                                         checkpoint['k'],
                                         checkpoint['g'],
                                         checkpoint['perm_method']))

    util.print_exp_settings(curr_seed, args.dataset, outfile_path, args.model, actfun,
                            util.get_model_params(model), sample_size, batch_size,
                            model.k, model.p, model.g, perm_method, resnet_ver,
                            resnet_width, args.optim, args.validation, curr_hparams)
    best_val_acc = 0

    if args.mix_pre_apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # ---- Start Training
    while epoch <= num_epochs:
        # Periodic mid-run checkpoint so the job can be pre-empted safely.
        if args.check_path != '':
            torch.save({'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'curr_seed': curr_seed,
                        'epoch': epoch,
                        'actfun': actfun,
                        'num_params': num_params,
                        'sample_size': sample_size,
                        'p': curr_p, 'k': curr_k, 'g': curr_g,
                        'perm_method': perm_method
                        }, mid_checkpoint_location)

        util.seed_all((curr_seed * args.num_epochs) + epoch)
        start_time = time.time()
        if args.mix_pre:
            scaler = torch.cuda.amp.GradScaler()

        # ---- Training
        model.train()
        total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
        for batch_idx, (x, targetx) in enumerate(loaders['aug_train']):
            x, targetx = x.to(device), targetx.to(device)
            optimizer.zero_grad()
            if args.mix_pre:
                # Native AMP: autocast forward, scaled backward/step.
                with torch.cuda.amp.autocast():
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                scaler.scale(train_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            elif args.mix_pre_apex:
                # Apex AMP: loss scaling via amp.scale_loss.
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss
                n += 1
                with amp.scale_loss(train_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss
                n += 1
                train_loss.backward()
                optimizer.step()
            # One-cycle schedules step per batch; others step per epoch below.
            if args.optim == 'onecycle' or args.optim == 'onecycle_sgd':
                scheduler.step()
            _, prediction = torch.max(output.data, 1)
            num_correct += torch.sum(prediction == targetx.data)
            num_total += len(prediction)
        epoch_aug_train_loss = total_train_loss / n
        epoch_aug_train_acc = num_correct * 1.0 / num_total

        # For 'combinact', record the learned per-layer mixing weights.
        alpha_primes = []
        alphas = []
        if model.actfun == 'combinact':
            for i, layer_alpha_primes in enumerate(model.all_alpha_primes):
                curr_alpha_primes = torch.mean(layer_alpha_primes, dim=0)
                curr_alphas = F.softmax(curr_alpha_primes, dim=0).data.tolist()
                curr_alpha_primes = curr_alpha_primes.tolist()
                alpha_primes.append(curr_alpha_primes)
                alphas.append(curr_alphas)

        model.eval()
        with torch.no_grad():
            # Validation on the augmented eval set.
            total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (y, targety) in enumerate(loaders['aug_eval']):
                y, targety = y.to(device), targety.to(device)
                output = model(y)
                val_loss = criterion(output, targety)
                total_val_loss += val_loss
                n += 1
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targety.data)
                num_total += len(prediction)
            epoch_aug_val_loss = total_val_loss / n
            epoch_aug_val_acc = num_correct * 1.0 / num_total

            # Validation on the clean eval set.
            total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (y, targety) in enumerate(loaders['eval']):
                y, targety = y.to(device), targety.to(device)
                output = model(y)
                val_loss = criterion(output, targety)
                total_val_loss += val_loss
                n += 1
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targety.data)
                num_total += len(prediction)
            epoch_val_loss = total_val_loss / n
            epoch_val_acc = num_correct * 1.0 / num_total

        lr_curr = 0
        for param_group in optimizer.param_groups:
            lr_curr = param_group['lr']
        print(
            " Epoch {}: LR {:1.5f} ||| aug_train_acc {:1.4f} | val_acc {:1.4f}, aug {:1.4f} ||| "
            "aug_train_loss {:1.4f} | val_loss {:1.4f}, aug {:1.4f} ||| time = {:1.4f}"
            .format(epoch, lr_curr, epoch_aug_train_acc, epoch_val_acc,
                    epoch_aug_val_acc, epoch_aug_train_loss, epoch_val_loss,
                    epoch_aug_val_loss, (time.time() - start_time)), flush=True
        )

        if args.hp_idx is None:
            hp_idx = -1
        else:
            hp_idx = args.hp_idx

        epoch_train_loss = 0
        epoch_train_acc = 0
        if epoch == num_epochs:
            # Final epoch: also measure loss/acc on the training sets themselves.
            with torch.no_grad():
                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loaders['aug_train']):
                    x, targetx = x.to(device), targetx.to(device)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                epoch_aug_train_loss = total_train_loss / n
                epoch_aug_train_acc = num_correct * 1.0 / num_total

                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loaders['train']):
                    x, targetx = x.to(device), targetx.to(device)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                # BUG FIX: was total_val_loss / n (stale validation total).
                epoch_train_loss = total_train_loss / n
                epoch_train_acc = num_correct * 1.0 / num_total

        # Outputting data to CSV at end of epoch
        with open(outfile_path, mode='a') as out_file:
            writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
            writer.writerow({'dataset': args.dataset,
                             'seed': curr_seed,
                             'epoch': epoch,
                             'time': (time.time() - start_time),
                             'actfun': model.actfun,
                             'sample_size': sample_size,
                             'model': args.model,
                             'batch_size': batch_size,
                             'alpha_primes': alpha_primes,
                             'alphas': alphas,
                             'num_params': util.get_model_params(model),
                             'var_nparams': args.var_n_params,
                             'var_nsamples': args.var_n_samples,
                             'k': curr_k,
                             'p': curr_p,
                             'g': curr_g,
                             'perm_method': perm_method,
                             'gen_gap': float(epoch_val_loss - epoch_train_loss),
                             'aug_gen_gap': float(epoch_aug_val_loss - epoch_aug_train_loss),
                             'resnet_ver': resnet_ver,
                             'resnet_width': resnet_width,
                             'epoch_train_loss': float(epoch_train_loss),
                             'epoch_train_acc': float(epoch_train_acc),
                             'epoch_aug_train_loss': float(epoch_aug_train_loss),
                             'epoch_aug_train_acc': float(epoch_aug_train_acc),
                             'epoch_val_loss': float(epoch_val_loss),
                             'epoch_val_acc': float(epoch_val_acc),
                             'epoch_aug_val_loss': float(epoch_aug_val_loss),
                             'epoch_aug_val_acc': float(epoch_aug_val_acc),
                             'hp_idx': hp_idx,
                             'curr_lr': lr_curr,
                             'found_lr': lr,
                             'hparams': curr_hparams,
                             'epochs': num_epochs
                             })

        epoch += 1
        if args.optim == 'rmsprop':
            scheduler.step()

        if args.checkpoints:
            # Keep the best-validation snapshot plus a rolling 'final' snapshot.
            if epoch_val_acc > best_val_acc:
                best_val_acc = epoch_val_acc
                torch.save({'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'curr_seed': curr_seed,
                            'epoch': epoch,
                            'actfun': actfun,
                            'num_params': num_params,
                            'sample_size': sample_size,
                            'p': curr_p, 'k': curr_k, 'g': curr_g,
                            'perm_method': perm_method
                            }, best_checkpoint_location)
            torch.save({'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'curr_seed': curr_seed,
                        'epoch': epoch,
                        'actfun': actfun,
                        'num_params': num_params,
                        'sample_size': sample_size,
                        'p': curr_p, 'k': curr_k, 'g': curr_g,
                        'perm_method': perm_method
                        }, final_checkpoint_location)
class Learner:
    """Wraps one training fold: model, loaders, optimizer, scheduler, logging.

    Trains with label-smoothed BCE, validates with plain BCEWithLogits, and
    checkpoints whenever validation loss improves.
    """

    def __init__(self, model, train_loader, valid_loader, fold, config, seed):
        self.config = config
        self.seed = seed
        self.device = self.config.device
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model.to(self.device)
        self.fold = fold
        # File + TensorBoard loggers, one pair per (seed, fold).
        self.logger = init_logger(
            config.log_dir, f'train_seed{self.seed}_fold{self.fold}.log')
        self.tb_logger = init_tb_logger(
            config.log_dir, f'train_seed{self.seed}_fold{self.fold}')
        # Dump the full config once (only for the first fold).
        if self.fold == 0:
            self.log('\n'.join(
                [f"{k} = {v}" for k, v in self.config.__dict__.items()]))
        # Smoothed BCE for training; plain BCE-with-logits for validation.
        self.criterion = SmoothBCEwLogits(smoothing=self.config.smoothing)
        self.evaluator = nn.BCEWithLogitsLoss()
        self.summary_loss = AverageMeter()
        self.history = {'train': [], 'valid': []}
        self.optimizer = Adam(self.model.parameters(), lr=config.lr,
                              weight_decay=self.config.weight_decay)
        self.scheduler = OneCycleLR(optimizer=self.optimizer, pct_start=0.1,
                                    div_factor=1e3, max_lr=1e-2,
                                    epochs=config.n_epochs,
                                    steps_per_epoch=len(train_loader))
        # GradScaler only when mixed precision is enabled.
        self.scaler = GradScaler() if config.fp16 else None
        self.epoch = 0
        self.best_epoch = 0
        self.best_loss = np.inf

    def train_one_epoch(self):
        """One pass over the training loader; returns the running mean loss."""
        self.model.train()
        self.summary_loss.reset()
        iters = len(self.train_loader)
        for step, (g_x, c_x, cate_x, labels, non_labels) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            # self.tb_logger.add_scalar('Train/lr', self.optimizer.param_groups[0]['lr'],
            #                           iters * self.epoch + step)
            labels = labels.to(self.device)
            non_labels = non_labels.to(self.device)
            g_x = g_x.to(self.device)
            c_x = c_x.to(self.device)
            cate_x = cate_x.to(self.device)
            batch_size = labels.shape[0]
            # ExitStack lets us enter autocast conditionally (fp16 only).
            with ExitStack() as stack:
                if self.config.fp16:
                    auto = stack.enter_context(autocast())
                outputs = self.model(g_x, c_x, cate_x)
                loss = self.criterion(outputs, labels)
            if self.config.fp16:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()
            self.summary_loss.update(loss.item(), batch_size)
            # Per-batch stepping is required for OneCycleLR; plateau-style
            # schedulers are stepped elsewhere with a metric.
            if self.scheduler.__class__.__name__ != 'ReduceLROnPlateau':
                self.scheduler.step()
        self.history['train'].append(self.summary_loss.avg)
        return self.summary_loss.avg

    def validation(self):
        """One pass over the validation loader; returns the mean eval loss."""
        self.model.eval()
        self.summary_loss.reset()
        iters = len(self.valid_loader)
        for step, (g_x, c_x, cate_x, labels, non_labels) in enumerate(self.valid_loader):
            with torch.no_grad():
                labels = labels.to(self.device)
                g_x = g_x.to(self.device)
                c_x = c_x.to(self.device)
                cate_x = cate_x.to(self.device)
                batch_size = labels.shape[0]
                outputs = self.model(g_x, c_x, cate_x)
                # Un-smoothed BCE so validation numbers are comparable.
                loss = self.evaluator(outputs, labels)
                self.summary_loss.update(loss.detach().item(), batch_size)
        self.history['valid'].append(self.summary_loss.avg)
        return self.summary_loss.avg

    def fit(self, epochs):
        """Train for ``epochs`` epochs; returns the loss history dict."""
        self.log(f'Start training....')
        for e in range(epochs):
            t = time.time()
            loss = self.train_one_epoch()
            # self.log(f'[Train] \t Epoch: {self.epoch}, loss: {loss:.6f}, time: {(time.time() - t):.2f}')
            self.tb_logger.add_scalar('Train/Loss', loss, self.epoch)
            t = time.time()
            loss = self.validation()
            # self.log(f'[Valid] \t Epoch: {self.epoch}, loss: {loss:.6f}, time: {(time.time() - t):.2f}')
            self.tb_logger.add_scalar('Valid/Loss', loss, self.epoch)
            # Checkpoint if validation improved.
            self.post_processing(loss)
            self.epoch += 1
        self.log(f'best epoch: {self.best_epoch}, best loss: {self.best_loss}')
        return self.history

    def post_processing(self, loss):
        """Save a full checkpoint whenever validation loss improves."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = self.epoch
            self.model.eval()
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_loss': self.best_loss,
                    'epoch': self.epoch,
                },
                f'{os.path.join(self.config.log_dir, f"{self.config.name}_seed{self.seed}_fold{self.fold}.pth")}'
            )
            self.log(f'best model: {self.epoch} epoch - loss: {loss:.6f}')

    def load(self, path):
        """Restore model/optimizer/scheduler state from ``path`` (CPU map)."""
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_loss = checkpoint['best_loss']
        # Resume from the epoch after the checkpointed one.
        self.epoch = checkpoint['epoch'] + 1

    def log(self, text):
        self.logger.info(text)
# move data to gpu img_cpu, label_cpu = batch img = img_cpu.to(device) label = label_cpu.to(device) # let model predict results output = model(img) # calc loss loss = nllloss(output, label) mean_loss.append(loss.cpu().item()) # backpropagate loss & adjust model loss.backward() optimizer.step() scheduler.step() # collect data for the f1 score lables_cat = torch.cat((lables_cat, label_cpu)) output_cat = torch.cat((output_cat, output.argmax(axis=1).cpu())) # calculate the f1 score train_f1 = f1_score(lables_cat, output_cat, average='macro') lables_cat = torch.empty(0, dtype=torch.long) output_cat = torch.empty(0, dtype=torch.long) for batch in tqdm(testloader, desc=f"Test {epoch}", leave=False): img_cpu, label_cpu = batch img = img_cpu.to(device) label_1hot = []
def train(args, training_features, model, tokenizer):
    """ Train the model

    Seq2seq fine-tuning loop: sets up W&B, (optional) apex fp16, an LR
    schedule selected by ``args.scheduler``, then trains with gradient
    accumulation, periodic logging and checkpoint saving.
    """
    wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=args,
               name=args.run_name)
    wandb.watch(model)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    else:
        amp = None
    # model recover
    recover_step = utils.get_max_epoch_model(args.output_dir)
    # if recover_step:
    #     model_recover_checkpoint = os.path.join(args.output_dir, "model.{}.bin".format(recover_step))
    #     logger.info(" ** Recover model checkpoint in %s ** ", model_recover_checkpoint)
    #     model_state_dict = torch.load(model_recover_checkpoint, map_location='cpu')
    #     optimizer_recover_checkpoint = os.path.join(args.output_dir, "optim.{}.bin".format(recover_step))
    #     checkpoint_state_dict = torch.load(optimizer_recover_checkpoint, map_location='cpu')
    #     checkpoint_state_dict['model'] = model_state_dict
    # else:
    checkpoint_state_dict = None
    model.to(args.device)
    model, optimizer = prepare_for_training(args, model, checkpoint_state_dict, amp=amp)

    # Effective batch size per node accounts for GPUs and accumulation steps.
    if args.n_gpu == 0 or args.no_cuda:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.gradient_accumulation_steps
    else:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.n_gpu * args.gradient_accumulation_steps
    train_batch_size = per_node_train_batch_size * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    global_step = recover_step if recover_step else 0

    # Derive step budget from epochs when not given explicitly.
    if args.num_training_steps == -1:
        args.num_training_steps = int(args.num_training_epochs * len(training_features) / train_batch_size)
    if args.warmup_portion:
        args.num_warmup_steps = args.warmup_portion * args.num_training_steps

    # Scheduler selection: linear warmup+decay, constant, or one-cycle.
    if args.scheduler == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.num_warmup_steps,
            num_training_steps=args.num_training_steps,
            last_epoch=-1)
    elif args.scheduler == "constant":
        scheduler = get_constant_schedule(optimizer, last_epoch=-1)
    elif args.scheduler == "1cycle":
        scheduler = OneCycleLR(optimizer,
                               max_lr=args.learning_rate,
                               total_steps=args.num_training_steps,
                               pct_start=args.warmup_portion,
                               anneal_strategy=args.anneal_strategy,
                               final_div_factor=1e4,
                               last_epoch=-1)
    else:
        assert False
    if checkpoint_state_dict:
        scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"])

    # Dataset yields (source_ids, target_ids, pseudo_ids, n_src, n_tgt);
    # 'offset' skips instances already consumed before a restart.
    train_dataset = utils.Seq2seqDatasetForBert(
        features=training_features,
        max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length,
        vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id,
        sep_id=tokenizer.sep_token_id,
        pad_id=tokenizer.pad_token_id,
        mask_id=tokenizer.mask_token_id,
        random_prob=args.random_prob,
        keep_prob=args.keep_prob,
        offset=train_batch_size * global_step,
        num_training_instances=train_batch_size * args.num_training_steps,
    )
    # Sanity-print the first few instances.
    logger.info("Check dataset:")
    for i in range(5):
        source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens = train_dataset.__getitem__(
            i)
        logger.info("Instance-%d" % i)
        logger.info("Source tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(source_ids)))
        logger.info("Target tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(target_ids)))
    logger.info("Mode = %s" % str(model))

    # Train!
    logger.info(" ***** Running training ***** *")
    logger.info(" Num examples = %d", len(training_features))
    logger.info(" Num Epochs = %.2f", len(train_dataset) / len(training_features))
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(" Batch size per node = %d", per_node_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size)
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", args.num_training_steps)

    if args.num_training_steps <= global_step:
        logger.info("Training is done. Please use a new dir or clean this dir!")
    else:
        # The training features are shuffled
        train_sampler = SequentialSampler(train_dataset) \
            if args.local_rank == -1 else DistributedSampler(train_dataset, shuffle=False)
        train_dataloader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=per_node_train_batch_size // args.gradient_accumulation_steps,
            collate_fn=utils.batch_list_to_batch_tensors)
        train_iterator = tqdm.tqdm(train_dataloader,
                                   initial=global_step,
                                   desc="Iter (loss=X.XXX, lr=X.XXXXXXX)",
                                   disable=args.local_rank not in [-1, 0])
        model.train()
        model.zero_grad()
        tr_loss, logging_loss = 0.0, 0.0
        for step, batch in enumerate(train_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'source_ids': batch[0],
                'target_ids': batch[1],
                'pseudo_ids': batch[2],
                'num_source_tokens': batch[3],
                'num_target_tokens': batch[4]
            }
            loss = model(**inputs)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            train_iterator.set_description(
                'Iter (loss=%5.3f) lr=%9.7f' %
                (loss.item(), scheduler.get_last_lr()[0]))
            # Scale the loss so accumulated gradients average correctly.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            logging_loss += loss.item()
            # Optimizer step only every gradient_accumulation_steps batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                # Periodic metric logging (rank 0 / single-process only).
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    wandb.log(
                        {
                            'lr': scheduler.get_last_lr()[0],
                            'loss': logging_loss / args.logging_steps
                        }, step=global_step)
                    logger.info(" Step [%d ~ %d]: %.2f", global_step - args.logging_steps,
                                global_step, logging_loss)
                    logging_loss = 0.0
                # Periodic checkpoint save (rank 0 / single-process only).
                if args.local_rank in [-1, 0] and args.save_steps > 0 and \
                        (global_step % args.save_steps == 0 or global_step == args.num_training_steps):
                    save_path = os.path.join(args.output_dir, "ckpt-%d" % global_step)
                    os.makedirs(save_path, exist_ok=True)
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(save_path)
                    # optim_to_save = {
                    #     "optimizer": optimizer.state_dict(),
                    #     "lr_scheduler": scheduler.state_dict(),
                    # }
                    # if args.fp16:
                    #     optim_to_save["amp"] = amp.state_dict()
                    # torch.save(
                    #     optim_to_save, os.path.join(args.output_dir, 'optim.{}.bin'.format(global_step)))
                    logger.info("Saving model checkpoint %d into %s", global_step, save_path)
                    wandb.save(f'{save_path}/*')
class Maskv3Agent:
    """Training/evaluation agent for the Maskv3 YOLO-with-masks model.

    Wires together datasets, dataloaders, model, AMP grad scaler, Adam +
    OneCycleLR, the YOLO mask loss, and a TensorBoard writer, all driven by a
    nested ``config`` dict. Checkpoints are written to
    ``config['train']['logdir']/best.pth``.
    """

    def __init__(self, config):
        # config: nested dict with 'train', 'dataset', 'dataloader', 'model',
        # 'optimizer' and 'valid' sections (see usages below).
        self.config = config
        # Train on device
        target_device = config['train']['device']
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.device = target_device
        else:
            self.device = "cpu"
        # Load dataset
        train_transform = get_yolo_transform(config['dataset']['size'],
                                             mode='train')
        valid_transform = get_yolo_transform(config['dataset']['size'],
                                             mode='test')
        train_dataset = YOLOMaskDataset(
            csv_file=config['dataset']['train']['csv'],
            img_dir=config['dataset']['train']['img_root'],
            mask_dir=config['dataset']['train']['mask_root'],
            anchors=config['dataset']['anchors'],
            scales=config['dataset']['scales'],
            n_classes=config['dataset']['n_classes'],
            transform=train_transform)
        valid_dataset = YOLOMaskDataset(
            csv_file=config['dataset']['valid']['csv'],
            img_dir=config['dataset']['valid']['img_root'],
            mask_dir=config['dataset']['valid']['mask_root'],
            anchors=config['dataset']['anchors'],
            scales=config['dataset']['scales'],
            n_classes=config['dataset']['n_classes'],
            transform=valid_transform)
        # DataLoader
        self.train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            collate_fn=maskv3_collate_fn,
            pin_memory=True,
            shuffle=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            collate_fn=maskv3_collate_fn,
            pin_memory=True,
            shuffle=False,
            drop_last=False)
        # Model
        model = Maskv3(
            # Detection Branch
            in_channels=config['model']['in_channels'],
            num_classes=config['model']['num_classes'],
            # Prototype Branch
            num_masks=config['model']['num_masks'],
            num_features=config['model']['num_features'],
        )
        self.model = model.to(self.device)
        # Facilitated anchor boxes with model: scale the per-scale anchors so
        # they live in cell coordinates for each feature-map resolution.
        torch_anchors = torch.tensor(config['dataset']['anchors'])  # (3, 3, 2)
        torch_scales = torch.tensor(config['dataset']['scales'])  # (3,)
        scaled_anchors = (  # (3, 3, 2)
            torch_anchors *
            (torch_scales.unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)))
        self.scaled_anchors = scaled_anchors.to(self.device)
        # Optimizer (AMP grad scaler pairs with autocast in the loops below)
        self.scaler = torch.cuda.amp.GradScaler()
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=config['optimizer']['lr'],
            weight_decay=config['optimizer']['weight_decay'],
        )
        # Scheduler: OneCycleLR is stepped once per batch in _train_one_epoch
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config['optimizer']['lr'],
            epochs=config['train']['n_epochs'],
            steps_per_epoch=len(self.train_loader),
        )
        # Loss function
        self.loss_fn = YOLOMaskLoss(num_classes=config['model']['num_classes'],
                                    num_masks=config['model']['num_masks'])
        # Tensorboard
        self.logdir = config['train']['logdir']
        self.board = SummaryWriter(logdir=config['train']['logdir'])
        # Training State
        self.current_epoch = 0
        self.current_map = 0

    def resume(self):
        """Restore model/optimizer/scheduler state from logdir/best.pth."""
        checkpoint_path = osp.join(self.logdir, 'best.pth')
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.current_map = checkpoint['current_map']
        self.current_epoch = checkpoint['current_epoch']
        print("Restore checkpoint at '{}'".format(self.current_epoch))

    def train(self):
        """Run the full training schedule.

        Each epoch: train, validate, print accuracies. Before
        ``config['valid']['when']`` epochs, checkpoint unconditionally; after
        that, checkpoint every 5th epoch only when mAP@0.5 improves.
        """
        for epoch in range(self.current_epoch + 1,
                           self.config['train']['n_epochs'] + 1):
            self.current_epoch = epoch
            self._train_one_epoch()
            self._validate()
            # NOTE(review): `accs` is computed for its printed side effect;
            # the returned dict is not used further here.
            accs = self._check_accuracy()
            if self.current_epoch < self.config['valid']['when']:
                self._save_checkpoint()
            if (self.current_epoch >= self.config['valid']['when']
                    and self.current_epoch % 5 == 0):
                mAP50 = self._check_map()
                if mAP50 > self.current_map:
                    self.current_map = mAP50
                    self._save_checkpoint()

    def finalize(self):
        """Report final mAP once training is complete."""
        self._check_map()

    def _train_one_epoch(self):
        """Train one epoch with AMP; steps scheduler per batch, logs epoch means."""
        n_epochs = self.config['train']['n_epochs']
        current_epoch = self.current_epoch
        current_lr = self.optimizer.param_groups[0]['lr']
        loop = tqdm(self.train_loader,
                    leave=True,
                    desc=(f"Train Epoch:{current_epoch}/{n_epochs}"
                          f", LR: {current_lr:.5f}"))
        obj_losses = []
        box_losses = []
        noobj_losses = []
        class_losses = []
        total_losses = []
        segment_losses = []
        self.model.train()
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            masks = [m.to(self.device) for m in masks]  # (nM_g, H, W)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            # Model prediction (mixed precision)
            with torch.cuda.amp.autocast():
                outs, prototypes = self.model(imgs)
                s1_loss = self.loss_fn(
                    outs[0], target_s1, self.scaled_anchors[0],  # Detection Branch
                    prototypes, masks,  # Prototype Branch
                )
                s2_loss = self.loss_fn(
                    outs[1], target_s2, self.scaled_anchors[1],  # Detection Branch
                    prototypes, masks,  # Prototype Branch
                )
                s3_loss = self.loss_fn(
                    outs[2], target_s3, self.scaled_anchors[2],  # Detection Branch
                    prototypes, masks,  # Prototype Branch
                )
                # Aggregate loss across the three scales
                obj_loss = s1_loss['obj_loss'] + s2_loss['obj_loss'] + s3_loss[
                    'obj_loss']
                box_loss = s1_loss['box_loss'] + s2_loss['box_loss'] + s3_loss[
                    'box_loss']
                noobj_loss = s1_loss['noobj_loss'] + s2_loss[
                    'noobj_loss'] + s3_loss['noobj_loss']
                class_loss = s1_loss['class_loss'] + s2_loss[
                    'class_loss'] + s3_loss['class_loss']
                segment_loss = s1_loss['segment_loss'] + s2_loss[
                    'segment_loss'] + s3_loss['segment_loss']
                total_loss = s1_loss['total_loss'] + s2_loss[
                    'total_loss'] + s3_loss['total_loss']
            # Moving average loss
            total_losses.append(total_loss.item())
            obj_losses.append(obj_loss.item())
            noobj_losses.append(noobj_loss.item())
            box_losses.append(box_loss.item())
            class_losses.append(class_loss.item())
            segment_losses.append(segment_loss.item())
            # Update Parameters (scaled backward for AMP)
            self.optimizer.zero_grad()
            self.scaler.scale(total_loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()
            # Update progress bar
            mean_total_loss = sum(total_losses) / len(total_losses)
            mean_obj_loss = sum(obj_losses) / len(obj_losses)
            mean_noobj_loss = sum(noobj_losses) / len(noobj_losses)
            mean_box_loss = sum(box_losses) / len(box_losses)
            mean_class_loss = sum(class_losses) / len(class_losses)
            mean_segment_loss = sum(segment_losses) / len(segment_losses)
            loop.set_postfix(
                loss=mean_total_loss,
                cls=mean_class_loss,
                box=mean_box_loss,
                obj=mean_obj_loss,
                noobj=mean_noobj_loss,
                segment=mean_segment_loss,
            )
        # Logging (epoch)
        epoch_total_loss = sum(total_losses) / len(total_losses)
        epoch_obj_loss = sum(obj_losses) / len(obj_losses)
        epoch_noobj_loss = sum(noobj_losses) / len(noobj_losses)
        epoch_box_loss = sum(box_losses) / len(box_losses)
        epoch_class_loss = sum(class_losses) / len(class_losses)
        epoch_segment_loss = sum(segment_losses) / len(segment_losses)
        self.board.add_scalar('Epoch Train Loss',
                              epoch_total_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train BOX Loss',
                              epoch_box_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train OBJ Loss',
                              epoch_obj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train NOOBJ Loss',
                              epoch_noobj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train CLASS Loss',
                              epoch_class_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train SEGMENT Loss',
                              epoch_segment_loss,
                              global_step=self.current_epoch)

    def _validate(self):
        """Evaluate epoch losses on the validation loader (no grad, AMP autocast)."""
        n_epochs = self.config['train']['n_epochs']
        current_epoch = self.current_epoch
        current_lr = self.optimizer.param_groups[0]['lr']
        loop = tqdm(self.valid_loader,
                    leave=True,
                    desc=(f"Valid Epoch:{current_epoch}/{n_epochs}"
                          f", LR: {current_lr:.5f}"))
        obj_losses = []
        box_losses = []
        noobj_losses = []
        class_losses = []
        total_losses = []
        segment_losses = []
        self.model.eval()
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            masks = [m.to(self.device) for m in masks]  # (nM_g, H, W)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            # Model Prediction
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outs, prototypes = self.model(imgs)
                    s1_loss = self.loss_fn(
                        outs[0], target_s1, self.scaled_anchors[0],  # Detection Branch
                        prototypes, masks,  # Prototype Branch
                    )
                    s2_loss = self.loss_fn(
                        outs[1], target_s2, self.scaled_anchors[1],  # Detection Branch
                        prototypes, masks,  # Prototype Branch
                    )
                    s3_loss = self.loss_fn(
                        outs[2], target_s3, self.scaled_anchors[2],  # Detection Branch
                        prototypes, masks,  # Prototype Branch
                    )
                    # Aggregate loss across the three scales
                    obj_loss = s1_loss['obj_loss'] + s2_loss['obj_loss'] + s3_loss[
                        'obj_loss']
                    box_loss = s1_loss['box_loss'] + s2_loss['box_loss'] + s3_loss[
                        'box_loss']
                    noobj_loss = s1_loss['noobj_loss'] + s2_loss[
                        'noobj_loss'] + s3_loss['noobj_loss']
                    class_loss = s1_loss['class_loss'] + s2_loss[
                        'class_loss'] + s3_loss['class_loss']
                    segment_loss = s1_loss['segment_loss'] + s2_loss[
                        'segment_loss'] + s3_loss['segment_loss']
                    total_loss = s1_loss['total_loss'] + s2_loss[
                        'total_loss'] + s3_loss['total_loss']
            # Moving average loss
            obj_losses.append(obj_loss.item())
            box_losses.append(box_loss.item())
            noobj_losses.append(noobj_loss.item())
            class_losses.append(class_loss.item())
            total_losses.append(total_loss.item())
            segment_losses.append(segment_loss.item())
            # Update progress bar
            mean_total_loss = sum(total_losses) / len(total_losses)
            mean_obj_loss = sum(obj_losses) / len(obj_losses)
            mean_noobj_loss = sum(noobj_losses) / len(noobj_losses)
            mean_box_loss = sum(box_losses) / len(box_losses)
            mean_class_loss = sum(class_losses) / len(class_losses)
            mean_segment_loss = sum(segment_losses) / len(segment_losses)
            loop.set_postfix(
                loss=mean_total_loss,
                cls=mean_class_loss,
                box=mean_box_loss,
                obj=mean_obj_loss,
                noobj=mean_noobj_loss,
                segment=mean_segment_loss,
            )
        # Logging (epoch)
        epoch_total_loss = sum(total_losses) / len(total_losses)
        epoch_obj_loss = sum(obj_losses) / len(obj_losses)
        epoch_noobj_loss = sum(noobj_losses) / len(noobj_losses)
        epoch_box_loss = sum(box_losses) / len(box_losses)
        epoch_class_loss = sum(class_losses) / len(class_losses)
        epoch_segment_loss = sum(segment_losses) / len(segment_losses)
        self.board.add_scalar('Epoch Valid Loss',
                              epoch_total_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid BOX Loss',
                              epoch_box_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid OBJ Loss',
                              epoch_obj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid NOOBJ Loss',
                              epoch_noobj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid CLASS Loss',
                              epoch_class_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid SEGMENT Loss',
                              epoch_segment_loss,
                              global_step=self.current_epoch)

    def _check_accuracy(self):
        """Compute objectness / no-objectness / class accuracies on validation.

        Returns a dict with keys 'cls', 'obj', 'noobj' (percentages).
        """
        tot_obj = 0
        tot_noobj = 0
        correct_obj = 0
        correct_noobj = 0
        correct_class = 0
        self.model.eval()
        loop = tqdm(self.valid_loader, leave=True, desc=f"Check ACC")
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            batch_size = imgs.size(0)
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            targets = [target_s1, target_s2, target_s3]
            # Model Prediction
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outs, prototypes = self.model(imgs)
            for scale_idx in range(len(outs)):
                # Get output
                pred = outs[scale_idx]
                target = targets[scale_idx]
                # Get mask (channel 4 of the target is the objectness flag)
                obj_mask = target[..., 4] == 1
                noobj_mask = target[..., 4] == 0
                # Count objects
                tot_obj += torch.sum(obj_mask)
                tot_noobj += torch.sum(noobj_mask)
                # Exception Handling: no object at this scale — only the
                # no-object cells can be scored, then skip to next scale
                if torch.sum(obj_mask) == 0:
                    obj_pred = torch.sigmoid(
                        pred[..., 4]) > self.config['valid']['conf_threshold']
                    correct_noobj += torch.sum(
                        obj_pred[noobj_mask] == target[..., 4][noobj_mask])
                    continue
                # Count number of correct classified object
                correct_class += torch.sum((torch.argmax(
                    pred[...,
                         5:5 + self.config['model']['num_classes']][obj_mask],
                    dim=-1) == target[..., 5][obj_mask]))
                # Count number of correct objectness & non-objectness
                obj_pred = torch.sigmoid(
                    pred[..., 4]) > self.config['valid']['conf_threshold']
                correct_obj += torch.sum(
                    obj_pred[obj_mask] == target[..., 4][obj_mask])
                correct_noobj += torch.sum(
                    obj_pred[noobj_mask] == target[..., 4][noobj_mask])
        # Aggregation Result (epsilon avoids division by zero)
        acc_obj = (correct_obj / (tot_obj + 1e-6)) * 100
        acc_cls = (correct_class / (tot_obj + 1e-6)) * 100
        acc_noobj = (correct_noobj / (tot_noobj + 1e-6)) * 100
        accs = {
            'cls': acc_cls.item(),
            'obj': acc_obj.item(),
            'noobj': acc_noobj.item()
        }
        print(f"Epoch {self.current_epoch} [Accs]: {accs}")
        return accs

    def _check_map(self):
        """Compute and print mAP@0.5 / mAP@0.75 over the validation set.

        Decodes per-cell predictions into normalized boxes, filters by
        confidence, applies per-class NMS, then evaluates. Returns mAP@0.5.
        """
        sample_idx = 0
        all_pred_bboxes = []
        all_true_bboxes = []
        self.model.eval()
        loop = tqdm(self.valid_loader, leave=True, desc="Check mAP")
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            batch_size = imgs.size(0)
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            targets = [target_s1, target_s2, target_s3]
            # Model Forward
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    preds, prototypes = self.model(imgs)
            # Convert cells to bboxes
            # =================================================================
            true_bboxes = [[] for _ in range(batch_size)]
            pred_bboxes = [[] for _ in range(batch_size)]
            for scale_idx, (pred, target) in enumerate(zip(preds, targets)):
                scale = pred.size(2)
                anchors = self.scaled_anchors[scale_idx]  # (3, 2)
                anchors = anchors.reshape(1, 3, 1, 1, 2)  # (1, 3, 1, 1, 2)
                # Convert prediction to correct format (YOLO decode, in place)
                pred[..., 0:2] = torch.sigmoid(pred[..., 0:2])  # (N, 3, S, S, 2)
                pred[..., 2:4] = torch.exp(
                    pred[..., 2:4]) * anchors  # (N, 3, S, S, 2)
                pred[..., 4:5] = torch.sigmoid(pred[..., 4:5])  # (N, 3, S, S, 1)
                pred_cls_probs = F.softmax(
                    pred[..., 5:5 + self.config['model']['num_classes']],
                    dim=-1)  # (N, 3, S, S, C)
                _, indices = torch.max(pred_cls_probs, dim=-1)  # (N, 3, S, S)
                indices = indices.unsqueeze(-1)  # (N, 3, S, S, 1)
                pred = torch.cat([pred[..., :5], indices],
                                 dim=-1)  # (N, 3, S, S, 6)
                # Convert coordinate system to normalized format (xywh)
                pboxes = cells_to_boxes(cells=pred,
                                        scale=scale)  # (N, 3, S, S, 6)
                tboxes = cells_to_boxes(cells=target,
                                        scale=scale)  # (N, 3, S, S, 6)
                # Filter out bounding boxes from all cells
                for idx, cell_boxes in enumerate(pboxes):
                    obj_mask = cell_boxes[
                        ..., 4] > self.config['valid']['conf_threshold']
                    boxes = cell_boxes[obj_mask]
                    pred_bboxes[idx] += boxes.tolist()
                # Filter out bounding boxes from all cells
                for idx, cell_boxes in enumerate(tboxes):
                    obj_mask = cell_boxes[..., 4] > 0.99
                    boxes = cell_boxes[obj_mask]
                    true_bboxes[idx] += boxes.tolist()
            # Perform NMS batch-by-batch
            # =================================================================
            for batch_idx in range(batch_size):
                pbboxes = torch.tensor(pred_bboxes[batch_idx])
                tbboxes = torch.tensor(true_bboxes[batch_idx])
                # Perform NMS class-by-class
                for c in range(self.config['model']['num_classes']):
                    # Filter pred boxes of specific class
                    nms_pred_boxes = nms_by_class(
                        target=c,
                        bboxes=pbboxes,
                        iou_threshold=self.config['valid']
                        ['nms_iou_threshold'])
                    nms_true_boxes = nms_by_class(
                        target=c,
                        bboxes=tbboxes,
                        iou_threshold=self.config['valid']
                        ['nms_iou_threshold'])
                    all_pred_bboxes.extend([[sample_idx] + box
                                            for box in nms_pred_boxes])
                    all_true_bboxes.extend([[sample_idx] + box
                                            for box in nms_true_boxes])
                sample_idx += 1
        # Compute mAP@0.5 & mAP@0.75
        # =================================================================
        # The format of the bboxes is (idx, x1, y1, x2, y2, conf, class)
        all_pred_bboxes = torch.tensor(all_pred_bboxes)  # (J, 7)
        all_true_bboxes = torch.tensor(all_true_bboxes)  # (K, 7)
        eval50 = mean_average_precision(
            all_pred_bboxes,
            all_true_bboxes,
            iou_threshold=0.5,
            n_classes=self.config['dataset']['n_classes'])
        eval75 = mean_average_precision(
            all_pred_bboxes,
            all_true_bboxes,
            iou_threshold=0.75,
            n_classes=self.config['dataset']['n_classes'])
        print((
            f"Epoch {self.current_epoch}:\n"
            f"\t-[mAP@0.5]={eval50['mAP']:.3f}, [Recall]={eval50['recall']:.3f}, [Precision]={eval50['precision']:.3f}\n"
            f"\t-[mAP@0.75]={eval75['mAP']:.3f}, [Recall]={eval75['recall']:.3f}, [Precision]={eval75['precision']:.3f}\n"
        ))
        return eval50['mAP']

    def _save_checkpoint(self):
        """Persist model/optimizer/scheduler state plus progress to best.pth."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'current_map': self.current_map,
            'current_epoch': self.current_epoch
        }
        checkpoint_path = osp.join(self.logdir, 'best.pth')
        torch.save(checkpoint, checkpoint_path)
        print("Save checkpoint at '{}'".format(checkpoint_path))
def __call__(
    self,
    net: nn.Module,
    train_iter: DataLoader,
    validation_iter: Optional[DataLoader] = None,
) -> None:
    """Train ``net`` for ``self.epochs`` epochs with Adam + OneCycleLR.

    :param net: the network to train; its forward returns a loss (or a
        list/tuple whose first element is the loss).
    :param train_iter: training dataloader yielding dict batches.
    :param validation_iter: optional dataloader; when given, a gradient-free
        validation batch is evaluated alongside each training batch.
    """
    wandb.watch(net, log="all", log_freq=self.num_batches_per_epoch)
    optimizer = Adam(net.parameters(),
                     lr=self.learning_rate,
                     weight_decay=self.weight_decay)
    lr_scheduler = OneCycleLR(
        optimizer,
        max_lr=self.maximum_learning_rate,
        steps_per_epoch=self.num_batches_per_epoch,
        epochs=self.epochs,
    )
    for epoch_no in range(self.epochs):
        # mark epoch start time
        tic = time.time()
        avg_epoch_loss = 0.0
        avg_epoch_loss_val = 0.0
        # FIX: stream batches lazily instead of materializing both loaders
        # with list(zip(range(1, loader.batch_size + 1), tqdm(loader))),
        # which (a) capped the epoch at `batch_size` batches rather than
        # `num_batches_per_epoch` and (b) loaded every batch into memory.
        val_iterator = (iter(validation_iter)
                        if validation_iter is not None else None)
        with tqdm(train_iter) as it:
            for batch_no, data_entry in enumerate(it, start=1):
                optimizer.zero_grad()
                # Optional gradient-free validation pass, paired with this
                # training batch; stops cleanly if validation is shorter.
                if val_iterator is not None:
                    try:
                        val_data_entry = next(val_iterator)
                    except StopIteration:
                        val_iterator = None
                    else:
                        with torch.no_grad():
                            inputs_val = [
                                v.to(self.device)
                                for v in val_data_entry.values()
                            ]
                            output_val = net(*inputs_val)
                            if isinstance(output_val, (list, tuple)):
                                loss_val = output_val[0]
                            else:
                                loss_val = output_val
                            avg_epoch_loss_val += loss_val.item()
                            wandb.log({"loss_val": loss_val.item()})
                inputs = [v.to(self.device) for v in data_entry.values()]
                output = net(*inputs)
                if isinstance(output, (list, tuple)):
                    loss = output[0]
                else:
                    loss = output
                avg_epoch_loss += loss.item()
                post_fix_dict = {
                    "avg_epoch_loss": avg_epoch_loss / batch_no,
                    "epoch": epoch_no,
                }
                if validation_iter is not None:
                    post_fix_dict["avg_epoch_loss_val"] = (
                        avg_epoch_loss_val / batch_no)
                # FIX: always log the training loss; the original skipped it
                # whenever a validation loader was supplied.
                wandb.log({"loss": loss.item()})
                it.set_postfix(post_fix_dict, refresh=False)
                loss.backward()
                if self.clip_gradient is not None:
                    nn.utils.clip_grad_norm_(net.parameters(),
                                             self.clip_gradient)
                optimizer.step()
                lr_scheduler.step()
                if self.num_batches_per_epoch == batch_no:
                    break
        # mark epoch end time and log time cost of current epoch
        toc = time.time()
class Trainer():
    """Iteration-based trainer for a Seq2Seq spelling-correction model.

    Builds vocab, synthetic-noise data generators, AdamW + OneCycleLR and a
    label-smoothing loss from module-level config constants (DEVICE,
    NUM_ITERS, BATCH_SIZE, ...). Training runs for a fixed number of
    iterations rather than epochs, with periodic printing, validation and
    checkpointing.
    """

    def __init__(self, alphabets_, list_ngram):
        # alphabets_: character set used to build the vocabulary
        # list_ngram: full corpus of n-gram samples, split 90/10 train/valid
        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))
        INPUT_DIM = self.vocab.__len__()
        OUTPUT_DIM = self.vocab.__len__()
        # Hyperparameters come from module-level constants
        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH
        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER
        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        self.metrics = MAX_SAMPLE_VALID
        logger = LOG
        # NOTE(review): when LOG is falsy, self.logger never exists but is
        # still used unconditionally in train() — confirm LOG is always set.
        if logger:
            self.logger = Logger(logger)
        self.iter = 0
        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)
        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        # OneCycleLR is stepped once per training iteration in step()
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)
        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)
        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)
        self.train_losses = []
        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)

    def train_test_split(self, list_phrases, test_size=0.1):
        """Deterministically split phrases into train/valid (no shuffling)."""
        list_phrases = list_phrases
        train_idx = int(len(list_phrases) * (1 - test_size))
        list_phrases_train = list_phrases[:train_idx]
        list_phrases_valid = list_phrases[train_idx:]
        return list_phrases_train, list_phrases_valid

    def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True):
        """Build a DataLoader over noisy n-gram samples (shuffled when training)."""
        dataset = AutoCorrectDataset(list_ngrams_np,
                                     transform_noise=synthesizer,
                                     vocab=vocab,
                                     maxlen=MAXLEN)
        shuffle = True if is_train else False
        gen = DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=shuffle,
                         drop_last=False)
        return gen

    def step(self, batch):
        """Run one optimization step on a batch; returns the scalar loss."""
        self.model.train()
        batch = self.batch_to_device(batch)
        src, tgt = batch['src'], batch['tgt']
        src, tgt = src.transpose(1, 0), tgt.transpose(
            1, 0)  # batch x src_len -> src_len x batch
        outputs = self.model(
            src, tgt)  # src : src_len x B, outpus : B x tgt_len x vocab
        outputs = outputs.view(-1, outputs.size(2))  # flatten(0, 1)
        # tgt is tgt_len x B; bring back to B x tgt_len before flattening so
        # the target order matches the flattened outputs
        tgt_output = tgt.transpose(0, 1).reshape(-1)
        loss = self.criterion(outputs, tgt_output)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        self.scheduler.step()
        loss_item = loss.item()
        return loss_item

    def train(self):
        """Main loop: run num_iters steps with periodic logging/validation.

        Weights are exported whenever full-sequence accuracy improves; a full
        checkpoint is written after every validation.
        """
        print("Begin training from iter: ", self.iter)
        total_loss = 0
        total_loader_time = 0
        total_gpu_time = 0
        best_acc = -1
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1
            start = time.time()
            try:
                batch = next(data_iter)
            except StopIteration:
                # Dataloader exhausted: restart it for the next pass
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start
            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start
            total_loss += loss
            self.train_losses.append((self.iter, loss))
            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)
            if self.iter % self.valid_every == 0:
                val_loss, preds, actuals, inp_sents = self.validate()
                acc_full_seq, acc_per_char, cer = self.precision(self.metrics)
                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, cer)
                print(info)
                print("--- Sentence predict ---")
                for pred, inp, label in zip(preds, inp_sents, actuals):
                    infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format(
                        pred, inp, label)
                    print(infor_predict)
                    self.logger.log(infor_predict)
                self.logger.log(info)
                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq
                self.save_checkpoint(self.checkpoint)

    def validate(self):
        """Compute mean validation loss (teacher forcing off) plus samples.

        Returns (mean_loss, first 3 predictions, labels, inputs).
        """
        self.model.eval()
        total_loss = []
        # Cap validation work at ~self.metrics samples
        max_step = self.metrics / self.batch_size
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                src, tgt = batch['src'], batch['tgt']
                src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)
                outputs = self.model(src, tgt, 0)  # turn off teaching force
                outputs = outputs.flatten(0, 1)
                tgt_output = tgt.flatten()
                loss = self.criterion(outputs, tgt_output)
                total_loss.append(loss.item())
                # NOTE(review): predict(5) runs over the whole valid_gen each
                # step — expensive; presumably intentional sampling, confirm.
                preds, actuals, inp_sents, probs = self.predict(5)
                del outputs
                del loss
                if step > max_step:
                    break
        total_loss = np.mean(total_loss)
        self.model.train()
        return total_loss, preds[:3], actuals[:3], inp_sents[:3]

    def predict(self, sample=None):
        """Decode validation batches (beam search or greedy translate).

        :param sample: stop once more than `sample` sentences were decoded.
        :returns: (pred_sents, actual_sents, inp_sents, prob of last batch)
        """
        pred_sents = []
        actual_sents = []
        inp_sents = []
        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)
            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['src'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['src'], self.model)
            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt'].tolist())
            inp_sent = self.vocab.batch_decode(batch['src'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            inp_sents.extend(inp_sent)
            if sample is not None and len(pred_sents) > sample:
                break
        return pred_sents, actual_sents, inp_sents, prob

    def precision(self, sample=None):
        """Return (full-sequence accuracy, per-char accuracy, CER) on samples."""
        pred_sents, actual_sents, _, _ = self.predict(sample=sample)
        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')
        cer = compute_accuracy(actual_sents, pred_sents, mode='CER')
        return acc_full_seq, acc_per_char, cer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        """Collect (optionally only wrong) predictions for visualization.

        NOTE(review): this method prepares `img_files` and `fontdict` but
        contains no plotting code — looks like an unfinished port, confirm.
        """
        pred_sents, actual_sents, img_files, probs = self.predict(sample)
        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
        img_files = img_files[:sample]
        fontdict = {'family': fontname, 'size': fontsize}

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Walk `sample` training examples (decoding targets); no rendering here.

        NOTE(review): expects 'img'/'tgt_input' keys that data_gen batches do
        not produce — likely vestigial from an image-based trainer; confirm.
        """
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())
                n += 1
                if n >= sample:
                    return

    def load_checkpoint(self, filename):
        """Restore full training state (model/optimizer/scheduler/progress)."""
        checkpoint = torch.load(filename)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Write full training state, creating the directory if needed."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': self.scheduler.state_dict()
        }
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights, skipping missing or shape-mismatched tensors."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape, required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                # Drop the incompatible tensor so strict=False can proceed
                del state_dict[name]
        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Write only the model weights, creating the directory if needed."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        """Move the 'src'/'tgt' tensors of a batch to the training device."""
        src = batch['src'].to(self.device, non_blocking=True)
        tgt = batch['tgt'].to(self.device, non_blocking=True)
        batch = {'src': src, 'tgt': tgt}
        return batch
def train(experiment_no: int,
          experiment_type: str,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          model: nn.Module,
          loss_func: Callable,
          lr: float,
          epochs: int,
          rho: float,
          params_to_train: Optional[int] = None,
          k: Optional[int] = None) -> None:
    """
    Trainer for model using SAM + SGD with one cycle lr schedule on single gpu
    :param experiment_no: uid for saving outputs into log and saving model checkpoint
    :param experiment_type:
    :param val_dataloader: validation dataloader
    :param train_dataloader: train dataloader
    :param model:
    :param loss_func:
    :param lr: learning rate
    :param rho: rho param for SAM
    :param epochs: number of epochs to train for
    :param params_to_train: Optional - number of layers to train, if 1 - just the top layer
    :param k: Optional - top-k for evaluation (default is 5)
    :return: nothing
    """
    log = TrainLogger(experiment_type=experiment_type,
                      experiment_no=experiment_no)
    # log hyperparams
    log.log_hyperparameters(lr=lr,
                            rho=rho,
                            epochs=epochs,
                            batch_size=train_dataloader.batch_size,
                            model_name=type(model).__name__,
                            params_to_train=params_to_train)
    torch.manual_seed(0)
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    total_params = sum(p.numel() for p in model.parameters())
    if params_to_train:
        # set only the last params_to_train params to be trainable
        # NOTE(review): `total_params` counts elements while `count` counts
        # parameter *tensors*, so the freeze boundary mixes units — confirm
        # the intended meaning of `params_to_train` before relying on it.
        count = 0
        for param in model.parameters():
            if total_params - count > params_to_train + 1:
                param.requires_grad = False
            count += 1
    model = model.to(device)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    base_optimizer = optim.SGD
    optimizer = SAM(trainable_params,
                    base_optimizer,
                    rho=rho,
                    lr=lr,
                    momentum=0.9,
                    weight_decay=1e-5)
    lr_sched = OneCycleLR(optimizer=optimizer,
                          max_lr=lr,
                          epochs=epochs,
                          steps_per_epoch=len(train_dataloader))
    log.log_hyperparameters(optimizer=type(optimizer).__name__,
                            lr_sched=type(lr_sched).__name__)
    for epoch in range(epochs):
        # FIX: keep the tqdm wrapper in a local so the dataloader is not
        # re-wrapped in tqdm layers on every epoch.
        epoch_iterator = tqdm(train_dataloader)
        running_loss = 0.0
        n_batches = 0
        for i, data in enumerate(epoch_iterator):
            model.train()
            images, labels = (d.to(device) for d in data)
            # first forward-backward step (SAM ascent step)
            outputs = model(images)
            loss = loss_func(outputs, labels)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)
            # second forward-backward step (SAM descent step)
            outputs2 = model(images)
            loss2 = loss_func(outputs2, labels)
            loss2.mean().backward()
            optimizer.second_step(zero_grad=True)
            # FIX: accumulate the loss (the original used `=`, so the epoch
            # report was only the final batch's loss). `.item()` detaches to
            # avoid holding the graph alive.
            running_loss += (loss.mean() + loss2.mean()).item()
            n_batches += 1
            lr_sched.step()
        avg_loss = running_loss / max(n_batches, 1)
        train_result = f"Train loss after {epoch + 1} epochs: {avg_loss}"
        print(train_result)
        log.log_result(train_result, model=model, epoch=epoch)
    evaluate_classifier(val_dataloader=val_dataloader,
                        device=device,
                        model=model,
                        k=k,
                        logger=log)
def train(args, writer):
    """Train the text classifier and log metrics to TensorBoard.

    :param args: namespace with data paths, model and optimization
        hyperparameters (vocab_size, learning_rate, num_train_epochs, ...).
    :param writer: a TensorBoard SummaryWriter; closed before returning.
    """
    # 1. Data processing
    # Rows of `train_dataset` are torchtext Examples holding 'id', 'category'
    # and 'news_text'; numericalization happens when the iterator is built.
    fields, train_dataset = build_and_cache_dataset(args, mode='train')
    # NEWS_TEXT and CATEGORY carry the vocabularies used by the iterator
    ID, CATEGORY, NEWS_TEXT = fields
    # Pre-trained word vectors
    vectors = Vectors(name=args.embed_path, cache=args.data_dir)
    # Build the text vocabulary and attach the matching pre-trained vectors
    # (vocab index i corresponds to row i of NEWS_TEXT.vocab.vectors).
    NEWS_TEXT.build_vocab(
        train_dataset,
        max_size=args.vocab_size,
        vectors=vectors,
        unk_init=torch.nn.init.xavier_normal_,
    )
    # Label vocabulary
    CATEGORY.build_vocab(train_dataset)
    # Instantiate the model
    model = TextClassifier(
        vocab_size=len(NEWS_TEXT.vocab),
        output_dim=args.num_labels,
        pad_idx=NEWS_TEXT.vocab.stoi[NEWS_TEXT.pad_token],
        dropout=args.dropout,
    )
    # Initialize the embedding layer from the pre-trained vectors
    model.embedding.from_pretrained(NEWS_TEXT.vocab.vectors)
    # Build the training iterator: batches are length-sorted within a batch,
    # shuffled between batches, padded to the longest sentence in the batch
    # (true lengths are kept alongside in batch.news_text).
    bucket_iterator = BucketIterator(
        train_dataset,
        batch_size=args.train_batch_size,
        sort_within_batch=True,
        shuffle=True,
        sort_key=lambda x: len(x.news_text),
        device=args.device,
    )
    # 2. Training
    model.to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(),
                     lr=args.learning_rate,
                     eps=args.adam_epsilon)
    # One-cycle schedule stepped once per batch
    scheduler = OneCycleLR(optimizer,
                           max_lr=args.learning_rate * 10,
                           epochs=args.num_train_epochs,
                           steps_per_epoch=len(bucket_iterator))
    global_step = 0
    model.zero_grad()
    train_trange = trange(0, args.num_train_epochs, desc="Train epoch")
    for _ in train_trange:
        epoch_iterator = tqdm(bucket_iterator, desc='Training')
        for step, batch in enumerate(epoch_iterator):
            model.train()
            # news_text: (padded token-id matrix, tensor of true lengths)
            news_text, news_text_lengths = batch.news_text
            category = batch.category
            # Forward pass and loss
            preds = model(news_text, news_text_lengths)
            loss = criterion(preds, category)
            loss.backward()
            writer.add_scalar('Train/Loss', loss.item(), global_step)
            writer.add_scalar('Train/lr',
                              scheduler.get_last_lr()[0], global_step)
            # NOTE: Update model, optimizer should update before scheduler
            optimizer.step()
            scheduler.step()
            # FIX: clear gradients after every update. The original zeroed
            # them only once before the epoch loop, so gradients accumulated
            # across all batches of all epochs.
            model.zero_grad()
            global_step += 1
            # Periodic evaluation
            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                # Returns dict of loss / precision / recall / f1
                results = evaluate(args, model, CATEGORY.vocab,
                                   NEWS_TEXT.vocab)
                for key, value in results.items():
                    writer.add_scalar("Eval/{}".format(key), value,
                                      global_step)
            # Periodic checkpointing
            if args.save_steps > 0 and global_step % args.save_steps == 0:
                save_model(args, model, optimizer, scheduler, global_step)
    writer.close()
def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr,
                   train_steps, attn_example, warmup_lr=False):
    """Run the Tacotron TTS training loop for roughly `train_steps` steps.

    Parameters
    ----------
    paths : Paths
        Project path container; `tts_attention`, `tts_mel_plot` and `tts_log`
        are used for plots and logging.
    model : Tacotron
        The model to train; its device is reused for the input batches.
    optimizer : torch.optim.Optimizer
        Optimizer whose learning rate is overwritten with `lr` below.
    train_set : iterable
        Batched dataset yielding `(x, m, ids, _)` tuples; `len(train_set)` is
        the number of batches per epoch.
    lr : float
        Base learning rate (also the `max_lr` of the warm-up schedule).
    train_steps : int
        Target number of optimizer steps; converted to whole epochs below.
    attn_example : hashable
        Sample id whose attention/spectrogram plots are saved when it appears
        in a batch.
    warmup_lr : bool
        When True, drive the LR with a OneCycleLR schedule stepped per batch.
    """
    device = next(
        model.parameters()).device  # use same device as model parameters
    # Force the optimizer onto the requested base LR (it may have been
    # restored from a checkpoint with a different value).
    for g in optimizer.param_groups:
        g['lr'] = lr
    total_iters = len(train_set)
    # Round up so at least `train_steps` optimizer steps are performed.
    epochs = train_steps // total_iters + 1
    if warmup_lr:
        # pct_start=0.5: spend half the schedule ramping up; div_factor=1000
        # starts from lr/1000; final_div_factor=1 ends back at the initial LR.
        lrs = OneCycleLR(optimizer, lr, total_steps=epochs * total_iters,
                         pct_start=0.5, div_factor=1000, anneal_strategy='cos',
                         final_div_factor=1)
    for e in range(1, epochs + 1):
        start = time.time()
        running_loss = 0
        # Perform 1 epoch
        for i, (x, m, ids, _) in enumerate(train_set, 1):
            x, m = x.to(device), m.to(device)
            # Parallelize model onto GPUS using workaround due to python bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                m1_hat, m2_hat, attention = data_parallel_workaround(
                    model, x, m)
            else:
                m1_hat, m2_hat, attention = model(x, m)
            # L1 loss on both the pre-net and post-net mel predictions.
            m1_loss = F.l1_loss(m1_hat, m)
            m2_loss = F.l1_loss(m2_hat, m)
            loss = m1_loss + m2_loss
            optimizer.zero_grad()
            loss.backward()
            if hp.tts_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.tts_clip_grad_norm).item()
                if np.isnan(grad_norm):
                    # Clipping already happened; this only warns, it does not
                    # skip the update.
                    print('grad_norm was NaN!')
            optimizer.step()
            if warmup_lr:
                # OneCycleLR is stepped per batch, matching total_steps above.
                lrs.step()
            running_loss += loss.item()
            avg_loss = running_loss / i
            speed = i / (time.time() - start)
            # `model.get_step()` is the global step counter kept by the model.
            step = model.get_step()
            k = step // 1000
            if step % hp.tts_checkpoint_every == 0:
                ckpt_name = f'taco_step{k}K'
                save_checkpoint('tts', paths, model, optimizer,
                                name=ckpt_name, is_silent=True)
            if attn_example in ids:
                idx = ids.index(attn_example)
                save_attention(np_now(attention[idx][:, :160]),
                               paths.tts_attention / f'{step}')
                save_spectrogram(np_now(m2_hat[idx]),
                                 paths.tts_mel_plot / f'{step}', 600)
            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:#.4} | {speed:#.2} steps/s | Step: {k}k | '
            stream(msg)
        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('tts', paths, model, optimizer,
                        is_silent=True)
        # NOTE(review): `msg` is only bound inside the batch loop — an empty
        # `train_set` would raise NameError here; confirm callers never pass
        # an empty dataset.
        model.log(paths.tts_log, msg)
        print(' ')
class OneCycleLRCallback(DefaultPyTorchSchedulerCallback):
    """
    Wraps PyTorch's `OneCycleLR` Scheduler as Callback.

    The scheduler is stepped once per training iteration (see
    :meth:`at_iter_begin`); nothing happens at epoch end.
    """

    def __init__(
            self,
            optimizer,
            max_lr,
            total_steps=None,
            epochs=None,
            steps_per_epoch=None,
            pct_start=0.3,
            anneal_strategy='cos',
            cycle_momentum=True,
            base_momentum=0.85,
            max_momentum=0.95,
            div_factor=25.0,
            final_div_factor=10000.0,
            last_epoch=-1):
        """
        Parameters
        ----------
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by
            providing a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps
            in the cycle if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for.
            This is used along with epochs in order to infer the total
            number of steps in the cycle if a value for total_steps is not
            provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps)
            spent increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is
            'max_momentum' and learning rate is 'base_lr'
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        last_epoch (int): The index of the last batch. This parameter is used
            when resuming a training job. Since `step()` should be invoked
            after each batch instead of after each epoch, this number
            represents the total number of *batches* computed, not the total
            number of epochs computed. When last_epoch=-1, the schedule is
            started from the beginning.
            Default: -1
        """
        super().__init__()
        # Pass everything by keyword: newer torch releases insert
        # `three_phase` before `last_epoch` in OneCycleLR's signature, so
        # positional passing would silently map `last_epoch` onto
        # `three_phase` and corrupt the schedule.
        self.scheduler = OneCycleLR(
            optimizer,
            max_lr=max_lr,
            total_steps=total_steps,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            pct_start=pct_start,
            anneal_strategy=anneal_strategy,
            cycle_momentum=cycle_momentum,
            base_momentum=base_momentum,
            max_momentum=max_momentum,
            div_factor=div_factor,
            final_div_factor=final_div_factor,
            last_epoch=last_epoch)

    def at_iter_begin(self, trainer, train, **kwargs):
        """
        Executes a single scheduling step (training iterations only).

        Parameters
        ----------
        trainer : :class:`PyTorchNetworkTrainer`
            the trainer class, which can be changed
        train : bool
            whether the current iteration belongs to the training phase;
            the scheduler is only advanced during training
        kwargs :
            additional keyword arguments

        Returns
        -------
        dict
            empty dict (the trainer is not modified)
        """
        if train:
            self.scheduler.step()
        return {}

    def at_epoch_end(self, trainer, **kwargs):
        # OneCycleLR is a per-batch schedule; nothing to do at epoch end.
        return {}
class Trainer():
    """OCR training driver: builds model/optimizer/scheduler from `config`,
    runs iteration-based training with periodic validation, checkpointing,
    prediction visualization and pseudo-label generation."""

    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):
        # NOTE(review): `augmentor=ImgAugTransform()` is a mutable default
        # argument shared across Trainer instances — confirm this is intended.
        self.config = config
        self.model, self.vocab = build_model(config)
        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']
        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']
        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']
        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']
        self.tensorboard_dir = config['monitor']['log_dir']
        if not os.path.exists(self.tensorboard_dir):
            os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)
        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)
        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']
        if self.is_finetuning:
            # Finetuning: fixed small LR, no LR schedule (scheduler stays None
            # and step() skips scheduler.step()).
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)
        else:
            # From-scratch training: OneCycleLR over the full iteration budget.
            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])
        if self.model.seq_modeling == 'crnn':
            # CTC loss for CRNN; blank index is the vocab pad token.
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)
        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))
        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            # Move restored optimizer state tensors onto the training device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))
            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))
        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor
        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]
        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)
        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        # NOTE(review): `self.valid_gen` only exists when `valid_annotation`
        # is truthy — this log line would raise AttributeError otherwise.
        self.logger.info("Number batch samples of valid: %d" %
                         len(self.valid_gen))
        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)

    def train(self):
        """Iteration-based training loop with periodic logging, validation,
        best-weight saving and checkpointing."""
        total_loss = 0
        total_loader_time = 0
        total_gpu_time = 0
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1
            start = time.time()
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: restart it so training is iteration-,
                # not epoch-, bounded.
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start
            start = time.time()
            # LOSS
            loss = self.step(batch)
            total_loss += loss
            self.train_losses.append((self.iter, loss))
            total_gpu_time += time.time() - start
            if self.iter % self.print_every == 0:
                info = 'Iter: {:06d} - Train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                lastest_loss = total_loss / self.print_every
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                self.logger.info(info)
            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_time = time.time()
                val_loss = self.validate()
                acc_full_seq, acc_per_char, wer = self.precision(self.metrics)
                self.logger.info("Iter: {:06d}, start validating".format(
                    self.iter))
                info = 'Iter: {:06d} - Valid loss: {:.3f} - Acc full seq: {:.4f} - Acc per char: {:.4f} - WER: {:.4f} - Time: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, wer,
                    time.time() - val_time)
                self.logger.info(info)
                if acc_full_seq > self.best_acc:
                    self.save_weights(self.tensorboard_dir + "/best.pt")
                    self.best_acc = acc_full_seq
                self.logger.info("Iter: {:06d} - Best acc: {:.4f}".format(
                    self.iter, self.best_acc))
                filename = 'last.pt'
                filepath = os.path.join(self.tensorboard_dir, filename)
                self.logger.info("Save checkpoint %s" % filename)
                self.save_checkpoint(filepath)
                # NOTE(review): `lastest_loss` is only bound once `print_every`
                # has fired; if `valid_every < print_every` this raises
                # NameError on the first validation — confirm configs always
                # satisfy print_every <= valid_every.
                log_loss = {'train loss': lastest_loss, 'val loss': val_loss}
                self.writer.add_scalars('Loss', log_loss, self.iter)
                self.writer.add_scalar('WER', wer, self.iter)

    def validate(self):
        """Compute the mean validation loss (model is restored to train mode
        before returning)."""
        self.model.eval()
        total_loss = []
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']
                # NOTE(review): here the mask is passed positionally, while
                # step() passes it as `tgt_key_padding_mask=` — verify both hit
                # the same model parameter.
                outputs = self.model(img, tgt_input, tgt_padding_mask)
                # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
                if self.model.seq_modeling == 'crnn':
                    length = batch['labels_len']
                    # NOTE(review): preds_size uses the configured batch_size;
                    # a smaller final batch (drop_last only holds for the
                    # train loader) would make this size wrong — confirm.
                    preds_size = torch.autograd.Variable(
                        torch.IntTensor([outputs.size(0)] * self.batch_size))
                    loss = self.criterion(outputs, tgt_output, preds_size,
                                          length)
                else:
                    outputs = outputs.flatten(0, 1)
                    tgt_output = tgt_output.flatten()
                    loss = self.criterion(outputs, tgt_output)
                total_loss.append(loss.item())
                del outputs
                del loss
        total_loss = np.mean(total_loss)
        self.model.train()
        return total_loss

    def predict(self, sample=None):
        """Decode the validation set.

        Returns (pred_sents, actual_sents, img_files, probs_sents, imgs_sents).
        When `sample` is given, stops once more than `sample` predictions are
        collected. The first batch is also rendered into TensorBoard.
        """
        pred_sents = []
        actual_sents = []
        img_files = []
        probs_sents = []
        imgs_sents = []
        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)
            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    # Beam search path yields no per-sentence probability.
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            imgs_sents.extend(batch['img'])
            img_files.extend(batch['filenames'])
            # NOTE(review): with beamsearch, `prob` is None and extend()
            # raises TypeError — confirm beamsearch configs never reach here.
            probs_sents.extend(prob)
            # Visualize in tensorboard
            if idx == 0:
                try:
                    num_samples = self.config['monitor']['num_samples']
                    fig = plt.figure(figsize=(12, 15))
                    imgs_samples = imgs_sents[:num_samples]
                    preds_samples = pred_sents[:num_samples]
                    actuals_samples = actual_sents[:num_samples]
                    probs_samples = probs_sents[:num_samples]
                    for id_img in range(len(imgs_samples)):
                        img = imgs_samples[id_img]
                        img = img.permute(1, 2, 0)
                        img = img.cpu().detach().numpy()
                        ax = fig.add_subplot(num_samples,
                                             1,
                                             id_img + 1,
                                             xticks=[],
                                             yticks=[])
                        plt.imshow(img)
                        # Green title for a correct prediction, red otherwise.
                        ax.set_title(
                            "LB: {} \n Pred: {:.4f}-{}".format(
                                actuals_samples[id_img],
                                probs_samples[id_img],
                                preds_samples[id_img]),
                            color=('green' if actuals_samples[id_img] ==
                                   preds_samples[id_img] else 'red'),
                            fontdict={
                                'fontsize': 18,
                                'fontweight': 'medium'
                            })
                    self.writer.add_figure('predictions vs. actuals',
                                           fig,
                                           global_step=self.iter)
                except Exception as error:
                    # Visualization is best-effort; never abort prediction.
                    print(error)
                    continue
            if sample != None and len(pred_sents) > sample:
                break
        return pred_sents, actual_sents, img_files, probs_sents, imgs_sents

    def precision(self, sample=None, measure_time=True):
        """Return (acc_full_seq, acc_per_char, wer) on the validation set."""
        t1 = time.time()
        pred_sents, actual_sents, _, _, _ = self.predict(sample=sample)
        time_predict = time.time() - t1
        sensitive_case = self.config['predictor']['sensitive_case']
        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        sensitive_case,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        sensitive_case,
                                        mode='per_char')
        wer = compute_accuracy(actual_sents,
                               pred_sents,
                               sensitive_case,
                               mode='wer')
        if measure_time:
            # Average decode time per sentence.
            print("Time: {:.4f}".format(time_predict / len(actual_sents)))
        return acc_full_seq, acc_per_char, wer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16,
                             save_fig=False):
        """Plot up to `sample` predictions in a 5-column grid; with
        `errorcase=True` only misclassified samples are shown."""
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)
        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
            imgs = [imgs[i] for i in wrongs]
        img_files = img_files[:sample]
        fontdict = {'family': fontname, 'size': fontsize}
        ncols = 5
        nrows = int(math.ceil(len(img_files) / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 15))
        for vis_idx in range(0, len(img_files)):
            row = vis_idx // ncols
            col = vis_idx % ncols
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            prob = probs[vis_idx]
            # (C, H, W) tensor -> (H, W, C) numpy for matplotlib.
            img = imgs[vis_idx].permute(1, 2, 0).cpu().detach().numpy()
            ax[row, col].imshow(img)
            ax[row, col].set_title(
                "Pred: {: <2} \n Actual: {} \n prob: {:.2f}".format(
                    pred_sent, actual_sent, prob),
                fontname=fontname,
                color='r' if pred_sent != actual_sent else 'g')
            ax[row, col].get_xaxis().set_ticks([])
            ax[row, col].get_yaxis().set_ticks([])
        plt.subplots_adjust()
        if save_fig:
            fig.savefig('vis_prediction.png')
        plt.show()

    def log_prediction(self, sample=16, csv_file='model.csv'):
        """Dump up to `sample` predictions to `csv_file`."""
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)
        save_predictions(csv_file, pred_sents, actual_sents, img_files)

    def vis_data(self, sample=20):
        """Plot `sample` training images with their decoded labels and save
        the figure to vis_dataset.png."""
        ncols = 5
        nrows = int(math.ceil(sample / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 12))
        num_plots = 0
        for idx, batch in enumerate(self.train_gen):
            for vis_idx in range(self.batch_size):
                row = num_plots // ncols
                col = num_plots % ncols
                img = batch['img'][vis_idx].numpy().transpose(1, 2, 0)
                # tgt_input is (seq, batch): transpose to index per sample.
                sent = self.vocab.decode(
                    batch['tgt_input'].T[vis_idx].tolist())
                ax[row, col].imshow(img)
                ax[row, col].set_title("Label: {: <2}".format(sent),
                                       fontsize=16,
                                       color='g')
                ax[row, col].get_xaxis().set_ticks([])
                ax[row, col].get_yaxis().set_ticks([])
                num_plots += 1
                if num_plots >= sample:
                    plt.subplots_adjust()
                    fig.savefig('vis_dataset.png')
                    return

    def load_checkpoint(self, filename):
        """Restore model/optimizer/scheduler/iteration state from a
        checkpoint created by save_checkpoint()."""
        checkpoint = torch.load(filename)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.best_acc = checkpoint['best_acc']

    def save_checkpoint(self, filename):
        """Write full training state (model, optimizer, scheduler, counters)
        to `filename`, creating parent directories as needed."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler':
            None if self.scheduler is None else self.scheduler.state_dict(),
            'best_acc': self.best_acc
        }
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights from either a full checkpoint or a bare
        state_dict; mismatched-shape entries are dropped with a warning."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        if self.is_checkpoint(state_dict):
            self.model.load_state_dict(state_dict['state_dict'])
        else:
            for name, param in self.model.named_parameters():
                if name not in state_dict:
                    print('{} not found'.format(name))
                elif state_dict[name].shape != param.shape:
                    print('{} missmatching shape, required {} but found {}'.
                          format(name, param.shape, state_dict[name].shape))
                    del state_dict[name]
            self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state_dict to `filename`."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), filename)

    def is_checkpoint(self, checkpoint):
        """Return True when `checkpoint` looks like a full checkpoint dict
        (has a 'state_dict' key) rather than a bare state_dict."""
        try:
            checkpoint['state_dict']
        except:
            return False
        else:
            return True

    def batch_to_device(self, batch):
        """Move the tensor fields of a batch dict onto the training device;
        filenames and label lengths are passed through unchanged."""
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)
        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames'],
            'labels_len': batch['labels_len']
        }
        return batch

    def data_gen(self,
                 lmdb_paths,
                 data_root,
                 annotation,
                 masked_language_model=True,
                 transform=None,
                 is_train=False):
        """Build a DataLoader over one or more OCR lmdb datasets."""
        datasets = []
        for lmdb_path in lmdb_paths:
            dataset = OCRDataset(
                lmdb_path=lmdb_path,
                root_dir=data_root,
                annotation_path=annotation,
                vocab=self.vocab,
                transform=transform,
                image_height=self.config['dataset']['image_height'],
                image_min_width=self.config['dataset']['image_min_width'],
                image_max_width=self.config['dataset']['image_max_width'],
                separate=self.config['dataset']['separate'],
                batch_size=self.batch_size,
                is_padding=self.is_padding)
            datasets.append(dataset)
        # NOTE(review): the concat condition checks the *train* lmdb list even
        # when building the valid loader — confirm this is intended.
        if len(self.train_lmdb) > 1:
            dataset = torch.utils.data.ConcatDataset(datasets)
        if self.is_padding:
            sampler = None
        else:
            sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        collate_fn = Collator(masked_language_model)
        # NOTE(review): DataLoader forbids sampler together with shuffle=True;
        # is_train=True with is_padding=False would raise — verify configs.
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=is_train,
                         drop_last=self.model.seq_modeling == 'crnn',
                         **self.config['dataloader'])
        return gen

    def step(self, batch):
        """Single optimization step on one batch; returns the scalar loss."""
        self.model.train()
        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']
        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        if self.model.seq_modeling == 'crnn':
            length = batch['labels_len']
            preds_size = torch.autograd.Variable(
                torch.IntTensor([outputs.size(0)] * self.batch_size))
            loss = self.criterion(outputs, tgt_output, preds_size, length)
        else:
            outputs = outputs.view(
                -1, outputs.size(2))  # flatten(0, 1) # B*S x N_class
            tgt_output = tgt_output.view(-1)  # flatten() # B*S
            loss = self.criterion(outputs, tgt_output)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients to max-norm 1 before the update.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        if not self.is_finetuning:
            # Scheduler only exists in from-scratch training (see __init__).
            self.scheduler.step()
        loss_item = loss.item()
        return loss_item

    def count_parameters(self, model):
        """Return the number of trainable parameters of `model`."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def gen_pseudo_labels(self, outfile=None):
        """Decode the validation loader and write `file||||pred||||prob`
        lines to `outfile` (for pseudo-labeling)."""
        pred_sents = []
        img_files = []
        probs_sents = []
        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)
            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)
            pred_sents.extend(pred_sent)
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)
        assert len(pred_sents) == len(img_files) and len(img_files) == len(
            probs_sents)
        with open(outfile, 'w', encoding='utf-8') as f:
            for anno in zip(img_files, pred_sents, probs_sents):
                f.write('||||'.join([anno[0], anno[1],
                                     str(float(anno[2]))]) + '\n')
class Cifar10Agent(BaseAgent):
    """CIFAR-10 training agent: finds an initial LR with LRFinder, trains
    with SGD (optionally driven by OneCycleLR), tracks misclassified samples
    and dumps run statistics to a JSON stats file."""

    def __init__(self, config):
        super().__init__(config)
        self.logger.info("TRAINING MODE ACTIVATED!!!")
        self.config = config
        self.use_cuda = self.config['use_cuda']
        self.visualize_inline = self.config['visualize_inline']
        # create network instance
        self.model = Net()
        # define data loader
        self.dataloader = dl(config=self.config)
        # initialize classes
        self.classes = self.dataloader.classes
        self.id2classes = {i: y for i, y in enumerate(self.classes)}
        # define loss
        self.loss = nn.CrossEntropyLoss()
        # find optim lr and set optimizer (also sets self.optimizer and
        # overwrites config['learning_rate'])
        self._find_optim_lr()
        # initialize weight decay
        self.l1_decay = self.config['l1_decay']
        self.l2_decay = self.config['l2_decay']
        # initialize step lr
        self.use_scheduler = self.config['use_scheduler']
        if self.use_scheduler:
            # `self.scheduler` temporarily holds the configured name, then is
            # replaced by the scheduler instance when it is "OneCycleLR".
            self.scheduler = self.config["scheduler"]["name"]
            if self.scheduler == "OneCycleLR":
                self.scheduler = OneCycleLR(
                    self.optimizer,
                    self.config['learning_rate'],
                    steps_per_epoch=len(self.dataloader.train_loader),
                    **self.config["scheduler"]["kwargs"])
            else:
                # Only OneCycleLR is supported; anything else disables
                # scheduling entirely.
                self.logger.info(
                    "WARNING : OneCycleLr Scheduler was not setup. Re-initializing use_scheduler to False"
                )
                self.use_scheduler = False
        # initialize Counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_metric = 0
        self.best_epoch = 0
        # initialize lr values list
        self.lr_list = []
        # initialize loss and accuracy arrays
        self.train_losses = []
        self.valid_losses = []
        self.train_acc = []
        self.valid_acc = []
        # initialize misclassified data
        self.misclassified = {}
        # initialize maximum accuracy
        self.max_accuracy = 0.0
        if not self.use_cuda and torch.cuda.is_available():
            self.logger.info(
                'WARNING : You have CUDA device, you should probably enable CUDA.'
            )
        # set manual seed
        self.manual_seed = self.config['seed']
        if self.use_cuda:
            torch.cuda.manual_seed(self.manual_seed)
            self.device = torch.device('cuda')
            torch.cuda.set_device(self.config['gpu_device'])
            self.model = self.model.to(self.device)
            self.loss = self.loss.to(self.device)
            self.logger.info("Program will RUN on ****GPU-CUDA****")
            print_cuda_statistics()
        else:
            torch.manual_seed(self.manual_seed)
            self.device = torch.device('cpu')
            self.logger.info("Program will RUN on ****CPU****")
        # summary of network
        print("****************************")
        print("**********NETWORK SUMMARY**********")
        summary(self.model, input_size=tuple(self.config['input_size']))
        # NOTE(review): the file handle opened inline here is never closed
        # explicitly — consider a `with` block.
        print(self.model,
              file=open(
                  os.path.join(self.config["summary_dir"], "model_arch.txt"),
                  "w"))
        print("****************************")
        self.stats_file_name = os.path.join(self.config["stats_dir"],
                                            self.config["model_stats_file"])

    def load_checkpoint(self, file_name):
        """
        Latest Checkpoint loader.

        Rebuilds the model and optimizer from scratch, then restores their
        state dicts and the misclassified-sample record.

        :param file_name: name of checkpoint file
        :return:
        """
        file_name = os.path.join(self.config["checkpoint_dir"], file_name)
        checkpoint = torch.load(file_name, map_location='cpu')
        self.model = Net()
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=self.config['learning_rate'],
                                   momentum=self.config['momentum'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        # NOTE(review): `is_best` is read but never used or returned.
        is_best = checkpoint["is_best"]
        self.misclassified = checkpoint['misclassified_data']

    def save_checkpoint(self, file_name='checkpoint.pth.tar', is_best=1):
        """
        Checkpoint Saver.

        :param file_name: name of checkpoint file path
        :param is_best: boolean flag indicating current metric is best so far
        :return:
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'valid_accuracy': self.max_accuracy,
            'misclassified_data': self.misclassified,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'is_best': is_best
        }
        file_name = os.path.join(self.config["checkpoint_dir"], file_name)
        torch.save(checkpoint, file_name)

    def _find_optim_lr(self):
        """
        Find an initial learning rate with an LR range test and set
        `self.optimizer` (SGD) to it; also overwrites
        `self.config['learning_rate']`.

        :return:
        """
        self.logger.info("FINDING OPTIM LEARNING RATE...")
        # Start from a tiny LR; the range test sweeps up to end_lr.
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=1e-7,
                                   momentum=self.config['momentum'])
        # NOTE(review): device is hard-coded to 'cuda' here even though the
        # agent supports CPU — confirm.
        lr_finder = LRFinder(self.model,
                             self.optimizer,
                             self.loss,
                             device='cuda')
        num_iter = (len(self.dataloader.train_loader.dataset) //
                    self.config["batch_size"]) * 5
        lr_finder.range_test(self.dataloader.train_loader,
                             end_lr=100,
                             num_iter=num_iter)
        if self.visualize_inline:
            lr_finder.plot()
        history = lr_finder.history
        # Pick the LR at the minimum observed loss.
        optim_lr = history["lr"][np.argmin(history["loss"])]
        self.logger.info("Learning rate with minimum loss : " + str(optim_lr))
        lr_finder.reset()
        # set optimizer to optim learning rate
        self.config["learning_rate"] = round(optim_lr, 3)
        self.logger.info(
            f"Setting optimizer to optim learning rate : {self.config['learning_rate']}"
        )
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=self.config["learning_rate"],
                                   momentum=self.config['momentum'])

    def visualize_set(self):
        """
        Visualize Train set.

        :return:
        """
        dataiter = iter(self.dataloader.train_loader)
        # NOTE(review): `.next()` is the legacy iterator API; modern Python
        # uses `next(dataiter)` — this line fails on current DataLoader
        # iterators.
        images, labels = dataiter.next()
        path = os.path.join(self.config["stats_dir"], 'training_images.png')
        visualize_data(images,
                       self.config['std'],
                       self.config['mean'],
                       30,
                       self.visualize_inline,
                       labels,
                       self.classes,
                       path=path)

    def run(self):
        """
        The main operator.

        :return:
        """
        try:
            self.train()
        except Exception as e:
            # Best-effort: any training failure is logged, not re-raised.
            self.logger.info(e)

    def train(self):
        """
        Main training iteration: one train + validate pass per epoch,
        recording the current LR of every param group first.

        :return:
        """
        for epoch in range(1, self.config['epochs'] + 1):
            for param_group in self.optimizer.param_groups:
                self.lr_list.append(param_group['lr'])
                self.logger.info(f"Current lr value = {param_group['lr']}")
            self.train_one_epoch()
            self.validate()
            self.current_epoch += 1

    def train_one_epoch(self):
        """
        One epoch of training.

        :return:
        """
        self.model.train()
        running_loss = 0.0
        running_correct = 0
        pbar = tqdm(self.dataloader.train_loader)
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss(output, target)
            # NOTE(review): regularize_loss receives the current `loss` as an
            # argument and its result is added on top — confirm it returns
            # only the penalty term, not loss+penalty.
            if self.l1_decay > 0.0:
                loss += regularize_loss(self.model, loss, self.l1_decay, 1)
            if self.l2_decay > 0.0:
                loss += regularize_loss(self.model, loss, self.l2_decay, 2)
            loss.backward()
            self.optimizer.step()
            if self.use_scheduler:
                # OneCycleLR steps per batch.
                self.scheduler.step()
            _, preds = torch.max(output.data, 1)
            # calculate running loss and accuracy
            running_loss += loss.item()
            running_correct += (preds == target).sum().item()
            pbar.set_description(
                desc=f'loss = {loss.item()} batch_id = {batch_idx}')
        # NOTE(review): running_loss sums per-batch mean losses but is divided
        # by the dataset size — this underestimates the per-sample loss.
        total_loss = running_loss / len(self.dataloader.train_loader.dataset)
        total_acc = 100. * running_correct / len(
            self.dataloader.train_loader.dataset)
        self.train_losses.append(total_loss)
        self.train_acc.append(total_acc)
        self.logger.info(
            f"TRAIN EPOCH : {self.current_epoch}\tLOSS : {total_loss:.4f}\tACC : {total_acc:.4f}"
        )

    def validate(self):
        """
        One cycle of model evaluation: computes loss/accuracy on the
        validation loader, records misclassified samples per epoch, and
        checkpoints when accuracy improves.

        :return:
        """
        self.model.eval()
        running_loss = 0.0
        running_correct = 0
        with torch.no_grad():
            for data, target in self.dataloader.valid_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                running_loss += self.loss(output, target).sum().item()
                pred = output.argmax(dim=1, keepdim=True)
                running_correct += pred.eq(target.view_as(pred)).sum().item()
                is_correct = pred.eq(target.view_as(pred))
                # Indices of wrong predictions within this batch.
                misclass_idx = (is_correct == 0).nonzero()[:, 0]
                for idx in misclass_idx:
                    if str(self.current_epoch) not in self.misclassified:
                        self.misclassified[str(self.current_epoch)] = []
                    self.misclassified[str(self.current_epoch)].append({
                        "target": target[idx],
                        "pred": pred[idx],
                        "img": data[idx]
                    })
        total_loss = running_loss / len(
            self.dataloader.valid_loader.dataset)
        total_acc = 100. * running_correct / len(
            self.dataloader.valid_loader.dataset)
        if (self.config['save_checkpoint']
                and total_acc > self.max_accuracy):
            self.max_accuracy = total_acc
            self.best_epoch = self.current_epoch
            try:
                self.save_checkpoint()
                self.logger.info("Saved Best Model")
            except Exception as e:
                self.logger.info(e)
        self.valid_losses.append(total_loss)
        self.valid_acc.append(total_acc)
        self.logger.info(
            f"VALID EPOCH : {self.current_epoch}\tLOSS : {total_loss:.4f}\tACC : {total_acc:.4f}"
        )

    def finalize(self):
        """
        Finalize operations: dump loss/accuracy/LR histories to the JSON
        stats file.

        :return:
        """
        self.logger.info(
            "Please wait while finalizing the operations.. Thank you")
        result = {
            "train_loss": self.train_losses,
            "train_acc": self.train_acc,
            "valid_loss": self.valid_losses,
            "valid_acc": self.valid_acc,
            "lr_list": self.lr_list
        }
        with open(self.stats_file_name, "w") as f:
            json.dump(result, f)
def __call__(
    self,
    net: nn.Module,
    train_iter: DataLoader,
    validation_iter: Optional[DataLoader] = None,
) -> None:
    """Train `net` for `self.epochs` epochs with Adam + OneCycleLR.

    Each epoch consumes exactly `self.num_batches_per_epoch` batches from
    `train_iter` (and, when given, from `validation_iter`). Batches are
    dicts whose values are tensors moved to `self.device`; the network may
    return either the loss or a (loss, ...) tuple.

    Parameters
    ----------
    net
        The network to train (mutated in place).
    train_iter
        Training data loader yielding dict batches.
    validation_iter
        Optional validation loader; evaluated once per epoch under
        `torch.no_grad()` in eval mode.
    """
    optimizer = Adam(net.parameters(),
                     lr=self.learning_rate,
                     weight_decay=self.weight_decay)
    # One scheduler step per training batch, matching steps_per_epoch.
    lr_scheduler = OneCycleLR(
        optimizer,
        max_lr=self.maximum_learning_rate,
        steps_per_epoch=self.num_batches_per_epoch,
        epochs=self.epochs,
    )
    for epoch_no in range(self.epochs):
        # mark epoch start time
        tic = time.time()
        cumm_epoch_loss = 0.0
        # FIX: the loops below consume exactly `num_batches_per_epoch`
        # batches, so that is the progress-bar total (was off by one).
        total = self.num_batches_per_epoch

        # training loop
        net.train()  # restore train mode after last epoch's validation
        with tqdm(train_iter, total=total) as it:
            for batch_no, data_entry in enumerate(it, start=1):
                optimizer.zero_grad()
                inputs = [v.to(self.device) for v in data_entry.values()]
                output = net(*inputs)
                # networks may return (loss, ...) tuples
                if isinstance(output, (list, tuple)):
                    loss = output[0]
                else:
                    loss = output
                cumm_epoch_loss += loss.item()
                avg_epoch_loss = cumm_epoch_loss / batch_no
                it.set_postfix(
                    {
                        "epoch": f"{epoch_no + 1}/{self.epochs}",
                        "avg_loss": avg_epoch_loss,
                    },
                    refresh=False,
                )
                loss.backward()
                if self.clip_gradient is not None:
                    nn.utils.clip_grad_norm_(net.parameters(),
                                             self.clip_gradient)
                optimizer.step()
                lr_scheduler.step()
                if self.num_batches_per_epoch == batch_no:
                    break

        # validation loop
        if validation_iter is not None:
            # FIX: run validation in eval mode so dropout is disabled and
            # batch-norm uses its running statistics.
            net.eval()
            cumm_epoch_loss_val = 0.0
            with tqdm(validation_iter, total=total, colour="green") as it:
                for batch_no, data_entry in enumerate(it, start=1):
                    inputs = [
                        v.to(self.device) for v in data_entry.values()
                    ]
                    with torch.no_grad():
                        output = net(*inputs)
                    if isinstance(output, (list, tuple)):
                        loss = output[0]
                    else:
                        loss = output
                    cumm_epoch_loss_val += loss.item()
                    avg_epoch_loss_val = cumm_epoch_loss_val / batch_no
                    it.set_postfix(
                        {
                            # avg_epoch_loss comes from the training loop
                            # above so both averages show side by side.
                            "epoch": f"{epoch_no + 1}/{self.epochs}",
                            "avg_loss": avg_epoch_loss,
                            "avg_val_loss": avg_epoch_loss_val,
                        },
                        refresh=False,
                    )
                    if self.num_batches_per_epoch == batch_no:
                        break

        # mark epoch end time and log time cost of current epoch
        toc = time.time()
def do_training(ogn,
                graph,
                trainloader,
                lr=1e-3,
                total_epochs=100,
                batch_per_epoch=1500,
                weight_decay=1e-8,
                l1=1e-2):
    """Train the graph network `ogn` with Adam + OneCycleLR.

    Parameters
    ----------
    ogn : nn.Module
        The model; moved to CUDA at the start of every epoch.
    graph
        Object exposing node features `graph.x` and targets `graph.y`.
    trainloader : callable
        Called once per sweep; yields neighbor-sampled subgraphs with
        `n_id` and `blocks[0]` (edge_index / n_id) attributes.
    lr : float
        Initial and maximum learning rate of the one-cycle schedule.
    total_epochs : int
        Number of epochs to run.
    batch_per_epoch : int
        Number of optimizer steps per epoch (scheduler steps once per batch,
        so it must match `steps_per_epoch`).
    weight_decay : float
        Adam weight decay.
    l1 : float
        L1 regularization strength forwarded to `new_loss`.

    Returns
    -------
    list of float
        Per-epoch average loss (per target node).
    """
    X = graph.x
    y = graph.y
    # Set up optimizer and per-batch one-cycle schedule.
    init_lr = lr
    opt = torch.optim.Adam(ogn.parameters(),
                           lr=init_lr,
                           weight_decay=weight_decay)
    sched = OneCycleLR(
        opt,
        max_lr=init_lr,
        steps_per_epoch=batch_per_epoch,  # len(trainloader)
        epochs=total_epochs,
        final_div_factor=1e5)
    all_losses = []
    for epoch in trange(0, total_epochs):
        ogn.cuda()
        total_loss = 0.0
        i = 0
        num_items = 0
        while i < batch_per_epoch:
            # FIX: guard against an empty loader — without it the outer
            # `while` would spin forever when trainloader() yields nothing.
            made_progress = False
            for subgraph in trainloader():
                if i >= batch_per_epoch:
                    break
                made_progress = True
                opt.zero_grad()
                # Target nodes come first; their 1-hop block is appended
                # after them, so block edges are offset by n_offset.
                n_offset = len(subgraph.n_id)
                cur_len = n_offset
                cur_edge_index = subgraph.blocks[0].edge_index.clone()
                cur_edge_index[0] += n_offset
                g = Data(x=torch.cat(
                    (X[subgraph.n_id], X[subgraph.blocks[0].n_id])).cuda(),
                         y=torch.cat((y[subgraph.n_id],
                                      y[subgraph.blocks[0].n_id])).cuda(),
                         edge_index=cur_edge_index.cuda())
                loss, reg = new_loss(ogn, g, cur_len, regularization=l1)
                # Normalize by the number of target nodes (+1 keeps the
                # divisor nonzero).
                ((loss + reg) / int(cur_len + 1)).backward()
                opt.step()
                sched.step()
                total_loss += loss.item()
                i += 1
                num_items += cur_len
            if not made_progress:
                break
        # FIX: avoid ZeroDivisionError when no items were processed.
        cur_loss = total_loss / num_items if num_items > 0 else 0.0
        all_losses.append(cur_loss)
        print(cur_loss, flush=True)
    return all_losses
class SegmentationTrainer():
    """Train/validate/test driver for 3D segmentation with apex AMP (O2).

    Builds loaders for the three splits, optionally constructs a
    ``ResidualUNet3D``, and pre-computes a one-cycle LR schedule that
    ``fit`` can replay per-epoch. Checkpoints (the whole trainer, via
    dill) are written after every epoch.
    """

    def __init__(self, name, model, train_set, valid_set, test_set, bs, lr, max_lr, loss_func, device):
        self.device = device
        self.name = name
        self.lr = lr
        self.bs = bs
        self.loss_function = loss_func
        self.metrics = compute_per_channel_dice
        self.train_set = train_set
        self.valid_set = valid_set
        self.test_set = test_set
        self.train_loader = DataLoader(self.train_set, batch_size=bs, shuffle=True, pin_memory=False)
        self.valid_loader = DataLoader(self.valid_set, batch_size=bs, shuffle=False, pin_memory=False)
        self.test_loader = DataLoader(self.test_set, batch_size=bs, shuffle=False, pin_memory=False)
        if model == 'ResidualUNet3D':
            model = ResidualUNet3D(1, 1, True).to(self.device).float()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        # A throwaway optimizer: only used to pre-compute the one-cycle LR
        # trajectory in fit() without touching the real (amp-wrapped) one.
        self.tmp_optimizer = optim.Adam(model.parameters(), lr=lr)
        self.model, self.optimizer = amp.initialize(model, optimizer, opt_level='O2')
        self.max_lr = max_lr
        self.lrs = []
        self.model_state_dicts = []

    def fit(self, epochs, print_each_img, use_cycle=False):
        """Train for ``epochs`` epochs; validates and checkpoints each epoch.

        :param epochs: number of epochs to run
        :param print_each_img: if True, print the loss of every batch
        :param use_cycle: if True, apply the pre-computed one-cycle LR per epoch
        """
        torch.cuda.empty_cache()
        self.train_losses = []
        self.valid_losses = []
        self.train_scores = []
        self.valid_scores = []
        # Replay OneCycleLR once per epoch on the throwaway optimizer just to
        # record the LR trajectory, then discard both.
        self.scheduler = OneCycleLR(self.tmp_optimizer, self.max_lr, epochs=epochs,
                                    steps_per_epoch=1, div_factor=25.0, final_div_factor=100)
        for epoch in range(epochs):
            self.scheduler.step()
            lr = self.tmp_optimizer.param_groups[0]['lr']
            self.lrs.append(lr)
        del self.tmp_optimizer, self.scheduler
        gc.collect()

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            total_score = 0
            print('epoch: ' + str(epoch))
            if use_cycle:
                lr = self.lrs[epoch]
                self.optimizer.param_groups[0]['lr'] = lr
            else:
                lr = self.lr
            print(lr)
            for index, batch in tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
                sample_img, sample_mask = batch
                sample_img = sample_img.to(self.device)
                sample_mask = sample_mask.to(self.device)
                predicted_mask = self.model(sample_img)
                loss = self.loss_function(predicted_mask, sample_mask)
                # score = self.metrics(predicted_mask,sample_mask)
                # apex mixed-precision backward pass
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                total_loss += loss.item()
                # total_score += score.item()
                if print_each_img:
                    print('batch loss: ' + str(loss.item()))
                # aggressively free GPU memory between (large volumetric) batches
                del batch, sample_img, sample_mask, predicted_mask, loss, scaled_loss
                gc.collect()
            torch.cuda.empty_cache()
            print('total_loss: ' + str(total_loss / len(self.train_loader)))
            self.train_losses.append(total_loss / len(self.train_loader))
            # self.train_scores.append(total_score/len(self.train_set))
            val_score = self.val()
            self.save_checkpoint(self.name, epoch, val_score)

    def val(self):
        """Evaluate on the validation split; returns mean dice score per batch."""
        torch.cuda.empty_cache()
        self.model.eval()
        total_val_loss = 0
        total_val_score = 0
        for index, val_batch in tqdm(enumerate(self.valid_loader), total=len(self.valid_loader)):
            val_sample_img, val_sample_mask = val_batch
            val_sample_img = val_sample_img.to(self.device)
            val_sample_mask = val_sample_mask.to(self.device)
            del val_batch
            gc.collect()
            with torch.no_grad():
                val_predicted_mask = self.model(val_sample_img)
                val_loss = self.loss_function(val_predicted_mask, val_sample_mask)
                val_score = self.metrics(val_predicted_mask, val_sample_mask)
                total_val_loss += val_loss.item()
                total_val_score += val_score.item()
            del val_sample_img, val_sample_mask, val_predicted_mask, val_loss, val_score
            gc.collect()
        # NOTE(review): printed average divides by len(valid_set) while the
        # stored/returned averages divide by len(valid_loader) — confirm intent.
        print('total_valid_score: ' + str(total_val_score / len(self.valid_set)))
        torch.cuda.empty_cache()
        self.valid_losses.append(total_val_loss / len(self.valid_loader))
        self.valid_scores.append(total_val_score / len(self.valid_loader))
        return total_val_score / len(self.valid_loader)

    def predict(self):
        """Evaluate on the test split; returns mean dice score per batch."""
        self.model.eval()
        total_test_loss = 0
        total_test_score = 0
        for index, test_batch in tqdm(enumerate(self.test_loader), total=len(self.test_loader)):
            test_sample_img, test_sample_mask = test_batch
            test_sample_img = test_sample_img.to(self.device)
            test_sample_mask = test_sample_mask.to(self.device)
            del test_batch
            gc.collect()
            with torch.no_grad():
                test_predicted_mask = self.model(test_sample_img)
                test_loss = self.loss_function(test_predicted_mask, test_sample_mask)
                test_score = self.metrics(test_predicted_mask, test_sample_mask)
                total_test_loss += test_loss.item()
                total_test_score += test_score.item()
            del test_sample_img, test_sample_mask, test_predicted_mask, test_loss, test_score
            gc.collect()
        print('test_score: ' + str(total_test_score / len(self.test_loader)))
        torch.cuda.empty_cache()
        self.test_score = total_test_score / len(self.test_loader)
        return total_test_score / len(self.test_loader)

    def save_checkpoint(self, name, epoch, val_score):
        """Serialize the whole trainer to ./results/<name>/epoch_<e>_val_score=<s>.pkl."""
        if not os.path.exists('./results'):
            os.mkdir('./results')
        if not os.path.exists('./results/' + name):
            os.mkdir('./results/' + name)
        # FIX: close the checkpoint file deterministically instead of leaking
        # the handle returned by open().
        with open('./results/' + name + '/epoch_' + str(epoch) +
                  '_val_score=' + str(val_score) + '.pkl', 'wb') as f:
            dill.dump(self, f)

    @staticmethod
    def load_best_checkpoint(name):
        """Load the checkpoint of ``name`` with the highest validation score."""
        checkpoints = sorted([
            checkpoint for checkpoint in os.listdir('./results/' + name)
            if checkpoint.startswith('epoch')
        ])
        # FIX: the original parsed only the fractional digits of the score
        # (`split('.')[1][:10]`), so e.g. 0.9 ("9") compared as smaller than
        # 0.85 ("85"). Strip the trailing ".pkl" and parse the full float.
        best_epoch = np.argmax([
            float(checkpoint.split('=')[1].rsplit('.', 1)[0])
            for checkpoint in checkpoints
        ])
        best_epoch = int(checkpoints[best_epoch].split('_')[1])
        print('best_epoch: ', best_epoch)
        best_checkpoint = [
            checkpoint for checkpoint in checkpoints
            if checkpoint.startswith('epoch_' + str(best_epoch))
        ][0]
        # FIX: close the file handle deterministically.
        with open('./results/' + name + '/' + best_checkpoint, 'rb') as f:
            return dill.load(f)
def train(ox: Oxentiel, env: gym.Env) -> None: """ Trains a policy gradient model with hyperparams from ``ox``. """ # Set shapes and dimensions for use in type hints. dims.RESOLUTION = ox.resolution dims.BATCH = ox.batch_size dims.ACTS = env.action_space.n shapes.OB = env.observation_space.shape # Make the policy object. ac = ActorCritic(shapes.OB[0], ox.hidden_dim, dims.ACTS) # Make optimizers. policy_optimizer = Adam(ac.pi.parameters(), lr=ox.lr) value_optimizer = Adam(ac.v.parameters(), lr=ox.lr) policy_scheduler = OneCycleLR(policy_optimizer, ox.lr, ox.lr_cycle_steps, pct_start=ox.pct_start) value_scheduler = OneCycleLR(value_optimizer, ox.lr, ox.lr_cycle_steps, pct_start=ox.pct_start) # Create a buffer object to store trajectories. rollouts = RolloutStorage(ox.batch_size, shapes.OB) # Get the initial observation. ob: Array[float, shapes.OB] ob = env.reset() oobs = [] co2s = [] mean_co2 = 0 num_oobs = 0 t_start = time.time() for i in range(ox.iterations): # Sample an action from the policy and estimate the value of current state. act: Array[int, ()] val: Array[float, ()] act, val = get_action(ac, ob) # Step the environment to get new observation, reward, done status, and info. next_ob: Array[float, shapes.OB] rew: int done: bool next_ob, rew, done, info = env.step(int(act)) # Get co2 lbs. co2s.append(info["co2"]) oobs.append(info["oob"]) # Add data for a timestep to the buffer. rollouts.add(ob, act, val, rew) # Don't forget to update the observation. ob = next_ob # If we reached a terminal state, or we completed a batch. if done or rollouts.batch_len == ox.batch_size: # Step 1: Compute advantages and critic targets. # Get episode length. ep_len = rollouts.ep_len dims.EP_LEN = ep_len # Retrieve values and rewards for the current episode. vals: Array[float, ep_len] rews: Array[float, ep_len] vals, rews = rollouts.get_episode_values_and_rewards() mean_rew = np.mean(rews) # The last value should be zero if this is the end of an episode. 
last_val: float = 0.0 if done else vals[-1] # Compute advantages and rewards-to-go. advs: Array[float, ep_len] = get_advantages(ox, rews, vals, last_val) rtgs: Array[float, ep_len] = get_rewards_to_go(ox, rews) # Record the episode length. if done: rollouts.lens.append(len(advs)) rollouts.rets.append(np.sum(rews)) # Reset the environment. ob = env.reset() mean_co2 = sum(co2s) num_oobs = sum([int(oob) for oob in oobs]) co2s = [] oobs = [] # Step 2: Reset vals and rews in buffer and record computed quantities. rollouts.vals[:] = 0 rollouts.rews[:] = 0 # Record advantages and rewards-to-go. j = rollouts.ep_start assert j + ep_len <= ox.batch_size rollouts.advs[j:j + ep_len] = advs rollouts.rtgs[j:j + ep_len] = rtgs rollouts.ep_start = j + ep_len rollouts.ep_len = 0 # If we completed a batch. if rollouts.batch_len == ox.batch_size: # Get batch data from the buffer. obs: Tensor[float, (ox.batch_size, *shapes.OB)] acts: Tensor[int, (ox.batch_size)] obs, acts, advs, rtgs = rollouts.get_batch() # Run a backward pass on the policy (actor). policy_optimizer.zero_grad() policy_loss = get_policy_loss(ac.pi, obs, acts, advs) policy_loss.backward() policy_optimizer.step() policy_scheduler.step() # Run a backward pass on the value function (critic). value_optimizer.zero_grad() value_loss = get_value_loss(ac.v, obs, rtgs) value_loss.backward() value_optimizer.step() value_scheduler.step() # Reset pointers. rollouts.batch_len = 0 rollouts.ep_start = 0 # Print statistics. lr = policy_scheduler.get_lr() print(f"Iteration: {i + 1} | ", end="") print(f"Time: {time.time() - t_start:.5f} | ", end="") print(f"Total co2: {mean_co2:.5f} | ", end="") print(f"Num OOBs: {num_oobs:.5f} | ", end="") print(f"LR: {lr} | ", end="") print(f"Mean reward for current batch: {mean_rew:.5f}") t_start = time.time() rollouts.rets = [] rollouts.lens = [] if i > 0 and i % ox.save_interval == 0: with open(ox.save_path, "wb") as model_file: torch.save(ac, model_file) print("=== saved model ===")
def Generator_NOGAN(self, epochs: int = 1, style_weight: float = 20., content_weight: float = 1.2,
                    recon_weight: float = 10., tv_weight: float = 1e-6,
                    loss: Optional[List[str]] = None):
    """Training Generator in NOGAN manner (Feature Loss only).

    :param epochs: maximum number of epochs to run
    :param style_weight: multiplier for the style loss term
    :param content_weight: multiplier for the content loss term
    :param recon_weight: multiplier for the reconstruction loss term
    :param tv_weight: multiplier for the total-variation loss term
    :param loss: names of the loss terms to activate; defaults to
        ``['content_loss']``. (FIX: the default was a mutable list literal —
        replaced with a ``None`` sentinel, same effective default.)
    """
    if loss is None:
        loss = ['content_loss']

    # Reset the generator LR, then sweep it with a one-cycle schedule
    # peaking at 10x the base LR.
    for g in self.optimizer_G.param_groups:
        g['lr'] = self.G_lr
    test_img = self.get_test_image()
    max_lr = self.G_lr * 10.
    lr_scheduler = OneCycleLR(self.optimizer_G, max_lr=max_lr,
                              steps_per_epoch=len(self.dataloader), epochs=epochs)
    meter = LossMeters(*loss)
    total_loss_arr = np.array([])

    for epoch in tqdm(range(epochs)):
        total_losses = 0
        meter.reset()
        for i, (style, smooth, train) in enumerate(self.dataloader, 0):
            # train = transform(test_img).unsqueeze(0)
            self.G.zero_grad(set_to_none=self.grad_set_to_none)
            train = train.to(self.device)
            generator_output = self.G(train)

            # Each term is computed only when requested; inactive terms
            # contribute a plain float 0. to the sum.
            if 'style_loss' in loss:
                style = style.to(self.device)
                style_loss = self.loss.style_loss(generator_output, style) * style_weight
            else:
                style_loss = 0.
            if 'content_loss' in loss:
                content_loss = self.loss.content_loss(generator_output, train) * content_weight
            else:
                content_loss = 0.
            if 'recon_loss' in loss:
                recon_loss = self.loss.reconstruction_loss(generator_output, train) * recon_weight
            else:
                recon_loss = 0.
            if 'tv_loss' in loss:
                tv_loss = self.loss.tv_loss(generator_output) * tv_weight
            else:
                tv_loss = 0.

            total_loss = content_loss + tv_loss + recon_loss + style_loss
            if self.fp16:
                # apex mixed-precision backward pass
                with amp.scale_loss(total_loss, self.optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                total_loss.backward()
            self.optimizer_G.step()
            lr_scheduler.step()
            total_losses += total_loss.detach()

            loss_dict = {
                'content_loss': content_loss,
                'style_loss': style_loss,
                'recon_loss': recon_loss,
                'tv_loss': tv_loss
            }
            # Only active terms are tensors; meter tracks exactly those.
            losses = [loss_dict[loss_type].detach() for loss_type in loss]
            meter.update(*losses)

        total_loss_arr = np.append(total_loss_arr, total_losses.item())
        self.writer.add_scalars(f'{self.init_time} NOGAN generator losses', meter.as_dict('sum'), epoch)
        self.write_weights(epoch + 1, write_D=False)
        self.eval_image(epoch, f'{self.init_time} reconstructed img', test_img)

        # Early-stopping heuristic: once the loss-curve gradient flattens
        # above the threshold, stop training.
        if epoch > 2:
            fig = plt.figure(figsize=(8, 8))
            X = np.arange(len(total_loss_arr))
            Y = np.gradient(total_loss_arr)
            plt.plot(X, Y)
            thresh = -1.0
            plt.axhline(thresh, c='r')
            plt.title(f"{self.init_time}")
            self.writer.add_figure(f"{self.init_time}", fig, epoch)
            if Y[-1] > thresh:
                break

    self.save_trial(epoch, f'G_NG_{self.init_time}')
def train(
    self,
    base_path: Union[Path, str],
    learning_rate: float = 0.1,
    mini_batch_size: int = 32,
    mini_batch_chunk_size: Optional[int] = None,
    max_epochs: int = 100,
    train_with_dev: bool = False,
    train_with_test: bool = False,
    monitor_train: bool = False,
    monitor_test: bool = False,
    main_evaluation_metric: Tuple[str, str] = ("micro avg", 'f1-score'),
    scheduler=AnnealOnPlateau,
    anneal_factor: float = 0.5,
    patience: int = 3,
    min_learning_rate: float = 0.0001,
    initial_extra_patience: int = 0,
    optimizer: torch.optim.Optimizer = SGD,
    cycle_momentum: bool = False,
    warmup_fraction: float = 0.1,
    embeddings_storage_mode: str = "cpu",
    checkpoint: bool = False,
    save_final_model: bool = True,
    anneal_with_restarts: bool = False,
    anneal_with_prestarts: bool = False,
    anneal_against_dev_loss: bool = False,
    batch_growth_annealing: bool = False,
    shuffle: bool = True,
    param_selection_mode: bool = False,
    write_weights: bool = False,
    num_workers: int = 6,
    sampler=None,
    use_amp: bool = False,
    amp_opt_level: str = "O1",
    eval_on_train_fraction: float = 0.0,
    eval_on_train_shuffle: bool = False,
    save_model_each_k_epochs: int = 0,
    tensorboard_comment: str = '',
    use_swa: bool = False,
    use_final_model_for_eval: bool = False,
    gold_label_dictionary_for_eval: Optional[Dictionary] = None,
    create_file_logs: bool = True,
    create_loss_file: bool = True,
    epoch: int = 0,
    use_tensorboard: bool = False,
    tensorboard_log_dir=None,
    metrics_for_tensorboard=[],  # NOTE(review): mutable default argument — safe only while never mutated
    optimizer_state_dict: Optional = None,  # NOTE(review): bare `Optional` is not a valid type — presumably Optional[dict]
    scheduler_state_dict: Optional = None,  # NOTE(review): same — presumably Optional[dict]
    save_optimizer_state: bool = False,
    **kwargs,
) -> dict:
    """
    Trains any class that implements the flair.nn.Model interface.
    :param base_path: Main path to which all output during training is logged and models are saved
    :param learning_rate: Initial learning rate (or max, if scheduler is OneCycleLR)
    :param mini_batch_size: Size of mini-batches during training
    :param mini_batch_chunk_size: If mini-batches are larger than this number, they get broken down into chunks of this size for processing purposes
    :param max_epochs: Maximum number of epochs to train. Terminates training if this number is surpassed.
    :param scheduler: The learning rate scheduler to use
    :param checkpoint: If True, a full checkpoint is saved at end of each epoch
    :param cycle_momentum: If scheduler is OneCycleLR, whether the scheduler should cycle also the momentum
    :param anneal_factor: The factor by which the learning rate is annealed
    :param patience: Patience is the number of epochs with no improvement the Trainer waits until annealing the learning rate
    :param min_learning_rate: If the learning rate falls below this threshold, training terminates
    :param warmup_fraction: Fraction of warmup steps if the scheduler is LinearSchedulerWithWarmup
    :param train_with_dev: If True, the data from dev split is added to the training data
    :param train_with_test: If True, the data from test split is added to the training data
    :param monitor_train: If True, training data is evaluated at end of each epoch
    :param monitor_test: If True, test data is evaluated at end of each epoch
    :param embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed), 'cpu' (embeddings are stored on CPU) or 'gpu' (embeddings are stored on GPU)
    :param save_final_model: If True, final model is saved
    :param anneal_with_restarts: If True, the last best model is restored when annealing the learning rate
    :param shuffle: If True, data is shuffled during training
    :param param_selection_mode: If True, testing is performed against dev data. Use this mode when doing parameter selection.
    :param num_workers: Number of workers in your data loader.
    :param sampler: You can pass a data sampler here for special sampling of data.
    :param eval_on_train_fraction: the fraction of train data to do the evaluation on, if 0. the evaluation is not performed on fraction of training data, if 'dev' the size is determined from dev set size
    :param eval_on_train_shuffle: if True the train data fraction is determined on the start of training and kept fixed during training, otherwise it's sampled at beginning of each epoch
    :param save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to '5', a model will be saved each 5 epochs. Default is 0 which means no model saving.
    :param main_evaluation_metric: Type of metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss will be used), currently only applicable for text_classification_model
    :param tensorboard_comment: Comment to use for tensorboard logging
    :param create_file_logs: If True, the logs will also be stored in a file 'training.log' in the model folder
    :param create_loss_file: If True, the loss will be writen to a file 'loss.tsv' in the model folder
    :param optimizer: The optimizer to use (typically SGD or Adam)
    :param epoch: The starting epoch (normally 0 but could be higher if you continue training model)
    :param use_tensorboard: If True, writes out tensorboard information
    :param tensorboard_log_dir: Directory into which tensorboard log files will be written
    :param metrics_for_tensorboard: List of tuples that specify which metrics (in addition to the main_score) shall be plotted in tensorboard, could be [("macro avg", 'f1-score'), ("macro avg", 'precision')] for example
    :param kwargs: Other arguments for the Optimizer
    :return:
    """

    # create a model card for this model with Flair and PyTorch version
    model_card = {'flair_version': flair.__version__, 'pytorch_version': torch.__version__}

    # also record Transformers version if library is loaded
    try:
        import transformers
        model_card['transformers_version'] = transformers.__version__
    except:  # NOTE(review): bare except — intentionally best-effort, but would also hide non-ImportError failures
        pass

    # remember all parameters used in train() call
    local_variables = locals()
    training_parameters = {}
    for parameter in signature(self.train).parameters:
        training_parameters[parameter] = local_variables[parameter]
    model_card['training_parameters'] = training_parameters

    # add model card to model
    self.model.model_card = model_card

    if use_tensorboard:
        try:
            from torch.utils.tensorboard import SummaryWriter
            if tensorboard_log_dir is not None and not os.path.exists(tensorboard_log_dir):
                os.mkdir(tensorboard_log_dir)
            writer = SummaryWriter(log_dir=tensorboard_log_dir, comment=tensorboard_comment)
            log.info(f"tensorboard logging path is {tensorboard_log_dir}")
        except:  # NOTE(review): bare except — silently disables tensorboard on any failure
            log_line(log)
            log.warning("ATTENTION! PyTorch >= 1.1.0 and pillow are required for TensorBoard support!")
            log_line(log)
            use_tensorboard = False
            pass

    if use_amp:
        if sys.version_info < (3, 0):
            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training."
            )

    # chunks default to full mini-batch size (i.e. no chunking)
    if mini_batch_chunk_size is None:
        mini_batch_chunk_size = mini_batch_size
    if learning_rate < min_learning_rate:
        min_learning_rate = learning_rate / 10

    initial_learning_rate = learning_rate

    # cast string to Path
    if type(base_path) is str:
        base_path = Path(base_path)
    base_path.mkdir(exist_ok=True, parents=True)

    if create_file_logs:
        log_handler = add_file_handler(log, base_path / "training.log")
    else:
        log_handler = None

    log_line(log)
    log.info(f'Model: "{self.model}"')
    log_line(log)
    log.info(f'Corpus: "{self.corpus}"')
    log_line(log)
    log.info("Parameters:")
    log.info(f' - learning_rate: "{learning_rate}"')
    log.info(f' - mini_batch_size: "{mini_batch_size}"')
    log.info(f' - patience: "{patience}"')
    log.info(f' - anneal_factor: "{anneal_factor}"')
    log.info(f' - max_epochs: "{max_epochs}"')
    log.info(f' - shuffle: "{shuffle}"')
    log.info(f' - train_with_dev: "{train_with_dev}"')
    log.info(f' - batch_growth_annealing: "{batch_growth_annealing}"')
    log_line(log)
    log.info(f'Model training base path: "{base_path}"')
    log_line(log)
    log.info(f"Device: {flair.device}")
    log_line(log)
    log.info(f"Embeddings storage mode: {embeddings_storage_mode}")

    if isinstance(self.model, SequenceTagger) and self.model.weight_dict and self.model.use_crf:
        log_line(log)
        log.warning(f'WARNING: Specified class weights will not take effect when using CRF')

    # check for previously saved best models in the current training folder and delete them
    self.check_for_and_delete_previous_best_models(base_path)

    # determine what splits (train, dev, test) to evaluate and log
    log_train = True if monitor_train else False
    log_test = True if (not param_selection_mode and self.corpus.test and monitor_test) else False
    log_dev = False if train_with_dev or not self.corpus.dev else True
    log_train_part = True if (eval_on_train_fraction == "dev" or eval_on_train_fraction > 0.0) else False

    if log_train_part:
        train_part_size = len(self.corpus.dev) if eval_on_train_fraction == "dev" \
            else int(len(self.corpus.train) * eval_on_train_fraction)

        assert train_part_size > 0
        if not eval_on_train_shuffle:
            train_part_indices = list(range(train_part_size))
            train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)

    # prepare loss logging file and set up header
    loss_txt = init_output_file(base_path, "loss.tsv") if create_loss_file else None

    weight_extractor = WeightExtractor(base_path)

    # if optimizer class is passed, instantiate:
    if inspect.isclass(optimizer):
        optimizer: torch.optim.Optimizer = optimizer(self.model.parameters(), lr=learning_rate, **kwargs)

    if use_swa:
        import torchcontrib
        optimizer = torchcontrib.optim.SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=learning_rate)

    if use_amp:
        self.model, optimizer = amp.initialize(
            self.model, optimizer, opt_level=amp_opt_level
        )

    # load existing optimizer state dictionary if it exists
    if optimizer_state_dict:
        optimizer.load_state_dict(optimizer_state_dict)

    # minimize training loss if training with dev data, else maximize dev score
    anneal_mode = "min" if train_with_dev or anneal_against_dev_loss else "max"
    best_validation_score = 100000000000 if train_with_dev or anneal_against_dev_loss else 0.

    dataset_size = len(self.corpus.train)
    if train_with_dev:
        dataset_size += len(self.corpus.dev)

    # if scheduler is passed as a class, instantiate
    if inspect.isclass(scheduler):
        if scheduler == OneCycleLR:
            scheduler = OneCycleLR(optimizer,
                                   max_lr=learning_rate,
                                   steps_per_epoch=dataset_size // mini_batch_size + 1,
                                   epochs=max_epochs - epoch,
                                   # if we load a checkpoint, we have already trained for epoch
                                   pct_start=0.0,
                                   cycle_momentum=cycle_momentum)
        elif scheduler == LinearSchedulerWithWarmup:
            steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size
            num_train_steps = int(steps_per_epoch * max_epochs)
            num_warmup_steps = int(num_train_steps * warmup_fraction)
            scheduler = LinearSchedulerWithWarmup(optimizer,
                                                  num_train_steps=num_train_steps,
                                                  num_warmup_steps=num_warmup_steps)
        else:
            # any other scheduler class is assumed to take the
            # AnnealOnPlateau-style constructor arguments
            scheduler = scheduler(
                optimizer,
                factor=anneal_factor,
                patience=patience,
                initial_extra_patience=initial_extra_patience,
                mode=anneal_mode,
                verbose=True,
            )

    # load existing scheduler state dictionary if it exists
    if scheduler_state_dict:
        scheduler.load_state_dict(scheduler_state_dict)

    # update optimizer and scheduler in model card
    model_card['training_parameters']['optimizer'] = optimizer
    model_card['training_parameters']['scheduler'] = scheduler

    if isinstance(scheduler, OneCycleLR) and batch_growth_annealing:
        raise ValueError("Batch growth with OneCycle policy is not implemented.")

    train_data = self.corpus.train

    # if training also uses dev/train data, include in training set
    if train_with_dev or train_with_test:
        parts = [self.corpus.train]
        if train_with_dev:
            parts.append(self.corpus.dev)
        if train_with_test:
            parts.append(self.corpus.test)
        train_data = ConcatDataset(parts)

    # initialize sampler if provided
    if sampler is not None:
        # init with default values if only class is provided
        if inspect.isclass(sampler):
            sampler = sampler()
        # set dataset to sample from
        sampler.set_dataset(train_data)
        shuffle = False

    dev_score_history = []
    dev_loss_history = []
    train_loss_history = []

    micro_batch_size = mini_batch_chunk_size

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        previous_learning_rate = learning_rate

        momentum = 0
        for group in optimizer.param_groups:
            if "momentum" in group:
                momentum = group["momentum"]

        for epoch in range(epoch + 1, max_epochs + 1):
            log_line(log)

            # update epoch in model card
            self.model.model_card['training_parameters']['epoch'] = epoch

            if anneal_with_prestarts:
                last_epoch_model_state_dict = copy.deepcopy(self.model.state_dict())

            if eval_on_train_shuffle:
                # NOTE(review): range(self.corpus.train) passes a dataset to
                # range() — presumably range(len(self.corpus.train)); confirm.
                train_part_indices = list(range(self.corpus.train))
                random.shuffle(train_part_indices)
                train_part_indices = train_part_indices[:train_part_size]
                train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)

            # get new learning rate
            for group in optimizer.param_groups:
                learning_rate = group["lr"]

            if learning_rate != previous_learning_rate and batch_growth_annealing:
                mini_batch_size *= 2

            # reload last best model if annealing with restarts is enabled
            if (
                    (anneal_with_restarts or anneal_with_prestarts)
                    and learning_rate != previous_learning_rate
                    and os.path.exists(base_path / "best-model.pt")
            ):
                if anneal_with_restarts:
                    log.info("resetting to best model")
                    self.model.load_state_dict(
                        self.model.load(base_path / "best-model.pt").state_dict()
                    )
                if anneal_with_prestarts:
                    log.info("resetting to pre-best model")
                    self.model.load_state_dict(
                        self.model.load(base_path / "pre-best-model.pt").state_dict()
                    )

            previous_learning_rate = learning_rate
            if use_tensorboard:
                writer.add_scalar("learning_rate", learning_rate, epoch)

            # stop training if learning rate becomes too small
            if ((not isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup))
                 and learning_rate < min_learning_rate)):
                log_line(log)
                log.info("learning rate too small - quitting training!")
                log_line(log)
                break

            batch_loader = DataLoader(
                train_data,
                batch_size=mini_batch_size,
                shuffle=shuffle if epoch > 1 else False,  # never shuffle the first epoch
                num_workers=num_workers,
                sampler=sampler,
            )

            self.model.train()

            train_loss: float = 0

            seen_batches = 0
            total_number_of_batches = len(batch_loader)

            # log progress roughly 10 times per epoch
            modulo = max(1, int(total_number_of_batches / 10))

            # process mini-batches
            batch_time = 0
            average_over = 0
            for batch_no, batch in enumerate(batch_loader):
                start_time = time.time()

                # zero the gradients on the model and optimizer
                self.model.zero_grad()
                optimizer.zero_grad()

                # if necessary, make batch_steps
                batch_steps = [batch]
                if len(batch) > micro_batch_size:
                    batch_steps = [batch[x: x + micro_batch_size] for x in range(0, len(batch), micro_batch_size)]

                # forward and backward for batch
                for batch_step in batch_steps:

                    # forward pass
                    loss = self.model.forward_loss(batch_step)

                    # NOTE(review): isinstance with typing.Tuple relies on
                    # runtime isinstance support — confirm Python/typing version.
                    if isinstance(loss, Tuple):
                        average_over += loss[1]
                        loss = loss[0]

                    # Backward
                    if use_amp:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    train_loss += loss.item()

                # do the optimizer step
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                optimizer.step()

                # do the scheduler step if one-cycle or linear decay
                if isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)):
                    scheduler.step()
                    # get new learning rate
                    for group in optimizer.param_groups:
                        learning_rate = group["lr"]
                        if "momentum" in group:
                            momentum = group["momentum"]
                        if "betas" in group:
                            momentum, _ = group["betas"]

                seen_batches += 1

                # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                store_embeddings(batch, embeddings_storage_mode)

                batch_time += time.time() - start_time
                if seen_batches % modulo == 0:
                    momentum_info = f' - momentum: {momentum:.4f}' if cycle_momentum else ''
                    intermittent_loss = train_loss / average_over if average_over > 0 else train_loss / seen_batches
                    log.info(
                        f"epoch {epoch} - iter {seen_batches}/{total_number_of_batches} - loss "
                        f"{intermittent_loss:.8f} - samples/sec: {mini_batch_size * modulo / batch_time:.2f}"
                        f" - lr: {learning_rate:.6f}{momentum_info}"
                    )
                    batch_time = 0
                    iteration = epoch * total_number_of_batches + batch_no
                    if not param_selection_mode and write_weights:
                        weight_extractor.extract_weights(self.model.state_dict(), iteration)

            if average_over != 0:
                train_loss /= average_over

            self.model.eval()

            log_line(log)
            log.info(f"EPOCH {epoch} done: loss {train_loss:.4f} - lr {learning_rate:.7f}")

            if use_tensorboard:
                writer.add_scalar("train_loss", train_loss, epoch)

            # evaluate on train / dev / test split depending on training settings
            result_line: str = ""

            if log_train:
                train_eval_result = self.model.evaluate(
                    self.corpus.train,
                    gold_label_type=self.model.label_type,
                    mini_batch_size=mini_batch_chunk_size,
                    num_workers=num_workers,
                    embedding_storage_mode=embeddings_storage_mode,
                    main_evaluation_metric=main_evaluation_metric,
                    gold_label_dictionary=gold_label_dictionary_for_eval,
                )
                result_line += f"\t{train_eval_result.log_line}"

                # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                store_embeddings(self.corpus.train, embeddings_storage_mode)

            if log_train_part:
                train_part_eval_result = self.model.evaluate(
                    train_part,
                    gold_label_type=self.model.label_type,
                    mini_batch_size=mini_batch_chunk_size,
                    num_workers=num_workers,
                    embedding_storage_mode=embeddings_storage_mode,
                    main_evaluation_metric=main_evaluation_metric,
                    gold_label_dictionary=gold_label_dictionary_for_eval,
                )
                result_line += f"\t{train_part_eval_result.loss}\t{train_part_eval_result.log_line}"

                log.info(
                    f"TRAIN_SPLIT : loss {train_part_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(train_part_eval_result.main_score, 4)}"
                )
                if use_tensorboard:
                    for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
                        writer.add_scalar(
                            f"train_{metric_class_avg_type}_{metric_type}",
                            train_part_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
                        )

            if log_dev:
                dev_eval_result = self.model.evaluate(
                    self.corpus.dev,
                    gold_label_type=self.model.label_type,
                    mini_batch_size=mini_batch_chunk_size,
                    num_workers=num_workers,
                    out_path=base_path / "dev.tsv",
                    embedding_storage_mode=embeddings_storage_mode,
                    main_evaluation_metric=main_evaluation_metric,
                    gold_label_dictionary=gold_label_dictionary_for_eval,
                )
                result_line += f"\t{dev_eval_result.loss}\t{dev_eval_result.log_line}"
                log.info(
                    f"DEV : loss {dev_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(dev_eval_result.main_score, 4)}"
                )
                # calculate scores using dev data if available
                # append dev score to score history
                dev_score_history.append(dev_eval_result.main_score)
                dev_loss_history.append(dev_eval_result.loss)

                dev_score = dev_eval_result.main_score

                # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                store_embeddings(self.corpus.dev, embeddings_storage_mode)

                if use_tensorboard:
                    writer.add_scalar("dev_loss", dev_eval_result.loss, epoch)
                    writer.add_scalar("dev_score", dev_eval_result.main_score, epoch)
                    for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
                        writer.add_scalar(
                            f"dev_{metric_class_avg_type}_{metric_type}",
                            dev_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
                        )

            if log_test:
                test_eval_result = self.model.evaluate(
                    self.corpus.test,
                    gold_label_type=self.model.label_type,
                    mini_batch_size=mini_batch_chunk_size,
                    num_workers=num_workers,
                    out_path=base_path / "test.tsv",
                    embedding_storage_mode=embeddings_storage_mode,
                    main_evaluation_metric=main_evaluation_metric,
                    gold_label_dictionary=gold_label_dictionary_for_eval,
                )
                result_line += f"\t{test_eval_result.loss}\t{test_eval_result.log_line}"
                log.info(
                    f"TEST : loss {test_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(test_eval_result.main_score, 4)}"
                )

                # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                store_embeddings(self.corpus.test, embeddings_storage_mode)

                if use_tensorboard:
                    writer.add_scalar("test_loss", test_eval_result.loss, epoch)
                    writer.add_scalar("test_score", test_eval_result.main_score, epoch)
                    for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
                        writer.add_scalar(
                            f"test_{metric_class_avg_type}_{metric_type}",
                            test_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
                        )

            # determine if this is the best model or if we need to anneal
            current_epoch_has_best_model_so_far = False
            # default mode: anneal against dev score
            if not train_with_dev and not anneal_against_dev_loss:
                if dev_score > best_validation_score:
                    current_epoch_has_best_model_so_far = True
                    best_validation_score = dev_score

                if isinstance(scheduler, AnnealOnPlateau):
                    scheduler.step(dev_score, dev_eval_result.loss)

            # alternative: anneal against dev loss
            if not train_with_dev and anneal_against_dev_loss:
                if dev_eval_result.loss < best_validation_score:
                    current_epoch_has_best_model_so_far = True
                    best_validation_score = dev_eval_result.loss

                if isinstance(scheduler, AnnealOnPlateau):
                    scheduler.step(dev_eval_result.loss)

            # alternative: anneal against train loss
            if train_with_dev:
                if train_loss < best_validation_score:
                    current_epoch_has_best_model_so_far = True
                    best_validation_score = train_loss

                if isinstance(scheduler, AnnealOnPlateau):
                    scheduler.step(train_loss)

            train_loss_history.append(train_loss)

            # determine bad epoch number
            try:
                bad_epochs = scheduler.num_bad_epochs
            except:  # NOTE(review): bare except — schedulers without num_bad_epochs fall back to 0
                bad_epochs = 0
            for group in optimizer.param_groups:
                new_learning_rate = group["lr"]
            if new_learning_rate != previous_learning_rate:
                bad_epochs = patience + 1
                if previous_learning_rate == initial_learning_rate:
                    bad_epochs += initial_extra_patience

            # log bad epochs
            log.info(f"BAD EPOCHS (no improvement): {bad_epochs}")

            if create_loss_file:
                # output log file
                with open(loss_txt, "a") as f:

                    # make headers on first epoch
                    if epoch == 1:
                        f.write(f"EPOCH\tTIMESTAMP\tBAD_EPOCHS\tLEARNING_RATE\tTRAIN_LOSS")
                        if log_train:
                            f.write("\tTRAIN_" + "\tTRAIN_".join(train_eval_result.log_header.split("\t")))
                        if log_train_part:
                            f.write("\tTRAIN_PART_LOSS\tTRAIN_PART_" + "\tTRAIN_PART_".join(
                                train_part_eval_result.log_header.split("\t")))
                        if log_dev:
                            f.write("\tDEV_LOSS\tDEV_" + "\tDEV_".join(dev_eval_result.log_header.split("\t")))
                        if log_test:
                            f.write("\tTEST_LOSS\tTEST_" + "\tTEST_".join(test_eval_result.log_header.split("\t")))

                    f.write(
                        f"\n{epoch}\t{datetime.datetime.now():%H:%M:%S}\t{bad_epochs}\t{learning_rate:.4f}\t{train_loss}"
                    )
                    f.write(result_line)

            # if checkpoint is enabled, save model at each epoch
            if checkpoint and not param_selection_mode:
                self.model.save(base_path / "checkpoint.pt", checkpoint=True)

            # Check whether to save best model
            if (
                    (not train_with_dev or anneal_with_restarts or anneal_with_prestarts)
                    and not param_selection_mode
                    and current_epoch_has_best_model_so_far
                    and not use_final_model_for_eval
            ):
                log.info("saving best model")
                self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state)

                if anneal_with_prestarts:
                    current_state_dict = self.model.state_dict()
                    self.model.load_state_dict(last_epoch_model_state_dict)
                    self.model.save(base_path / "pre-best-model.pt")
                    self.model.load_state_dict(current_state_dict)

            if save_model_each_k_epochs > 0 and not epoch % save_model_each_k_epochs:
                print("saving model of current epoch")
                model_name = "model_epoch_" + str(epoch) + ".pt"
                self.model.save(base_path / model_name, checkpoint=save_optimizer_state)

        if use_swa:
            optimizer.swap_swa_sgd()

        # if we do not use dev data for model selection, save final model
        if save_final_model and not param_selection_mode:
            self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)

    except KeyboardInterrupt:
        log_line(log)
        log.info("Exiting from training early.")

        if use_tensorboard:
            writer.close()

        if not param_selection_mode:
            log.info("Saving model ...")
            self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
            log.info("Done.")

    # test best model if test data is present
    if self.corpus.test and not train_with_test:
        final_score = self.final_test(
            base_path=base_path,
            eval_mini_batch_size=mini_batch_chunk_size,
            num_workers=num_workers,
            main_evaluation_metric=main_evaluation_metric,
            gold_label_dictionary_for_eval=gold_label_dictionary_for_eval,
        )
    else:
        final_score = 0
        log.info("Test data not provided setting final score to 0")

    if create_file_logs:
        log_handler.close()
        log.removeHandler(log_handler)

    if use_tensorboard:
        writer.close()

    return {
        "test_score": final_score,
        "dev_score_history": dev_score_history,
        "train_loss_history": train_loss_history,
        "dev_loss_history": dev_loss_history,
    }
def Discriminator_NOGAN(
        self,
        epochs: int = 3,
        adv_weight: float = 1.0,
        edge_weight: float = 1.0,
        loss: List[str] = ['real_adv_loss', 'fake_adv_loss', 'gray_loss']):
    # Pre-train the discriminator alone ("NOGAN" phase): the generator is only
    # used to produce fake samples (detached), never updated here.
    #
    # Params:
    #   epochs      -- number of passes over self.dataloader
    #   adv_weight  -- global scale applied to each adversarial loss term
    #   edge_weight -- NOTE(review): accepted but never used in this body
    #   loss        -- which loss components to track in the meter
    #                  NOTE(review): mutable default argument; harmless while
    #                  only read, but consider a None sentinel
    """https://discuss.pytorch.org/t/scheduling-batch-size-in-dataloader/46443/2"""
    # Reset the discriminator optimizer's LR before building the 1-cycle schedule.
    for g in self.optimizer_D.param_groups:
        g['lr'] = self.D_lr
    max_lr = self.D_lr * 10.
    lr_scheduler = OneCycleLR(self.optimizer_D,
                              max_lr=max_lr,
                              steps_per_epoch=len(self.dataloader),
                              epochs=epochs)
    meter = LossMeters(*loss)
    # NOTE(review): total_loss_arr is never appended to below, so the
    # epoch > 2 early-stop branch will fail on np.gradient([]) / Y[-1] —
    # a per-epoch loss append appears to be missing. TODO confirm.
    total_loss_arr = np.array([])
    if self.init_time is None:
        self.init_time = datetime.datetime.now().strftime("%H:%M")
    for epoch in tqdm(range(epochs)):
        meter.reset()
        for i, (style, smooth, train) in enumerate(self.dataloader, 0):
            # train = transform(test_img).unsqueeze(0)
            self.D.zero_grad(set_to_none=self.grad_set_to_none)
            train = train.to(self.device)
            style = style.to(self.device)
            # Fake samples come from the (frozen) generator; detach() keeps
            # gradients out of G.
            generator_output = self.G(train)
            real_adv_loss = self.D(style).view(-1)
            fake_adv_loss = self.D(generator_output.detach()).view(-1)
            # LSGAN-style squared losses: real pushed toward 1, fake toward 0.
            real_adv_loss = torch.pow(real_adv_loss - 1, 2).mean() * 1.7 * adv_weight
            fake_adv_loss = torch.pow(fake_adv_loss, 2).mean() * 1.7 * adv_weight
            # Greyscale copies of style images are also treated as "fake" so D
            # learns to penalize colorless output.
            gray_train = tr.inv_gray_transform(style)
            greyscale_output = self.D(gray_train).view(-1)
            gray_loss = torch.pow(greyscale_output, 2).mean() * 1.7 * adv_weight
            "According to AnimeGANv2 implementation, every loss is scaled by individual weights and then scaled with adv_weight"
            "https://github.com/TachibanaYoshino/AnimeGANv2/blob/5946b6afcca5fc28518b75a763c0f561ff5ce3d6/tools/ops.py#L217"
            total_loss = real_adv_loss + fake_adv_loss + gray_loss
            if self.fp16:
                # apex mixed-precision path: scale the loss before backward.
                with amp.scale_loss(total_loss, self.optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                total_loss.backward()
            self.optimizer_D.step()
            # OneCycleLR must step once per batch, not per epoch.
            lr_scheduler.step()
            loss_dict = {
                'real_adv_loss': real_adv_loss,
                'fake_adv_loss': fake_adv_loss,
                'gray_loss': gray_loss
            }
            # Only the components requested via `loss` are accumulated.
            losses = [loss_dict[loss_type].detach() for loss_type in loss]
            meter.update(*losses)
        self.writer.add_scalars(
            f'{self.init_time} NOGAN discriminator loss',
            meter.as_dict('sum'), epoch)
        self.writer.flush()
        # After a warm-up of 3 epochs, stop once the loss-gradient flattens
        # above `thresh` (see NOTE above: total_loss_arr is currently empty).
        if epoch > 2:
            fig = plt.figure(figsize=(8, 8))
            X = np.arange(len(total_loss_arr))
            Y = np.gradient(total_loss_arr)
            plt.plot(X, Y)
            thresh = -1.0
            plt.axhline(thresh, c='r')
            plt.title(f"{self.init_time}")
            self.writer.add_figure(f"{self.init_time}", fig, epoch)
            if Y[-1] > thresh:
                break
class Trainer():
    # OCR sequence-to-sequence trainer: wires a model/vocab built from a config
    # dict to an AdamW optimizer, OneCycleLR schedule, label-smoothing loss and
    # LMDB-backed data loaders, and provides train/validate/predict utilities.

    def __init__(self, config, pretrained=True):
        # config: nested dict with 'device', 'trainer', 'dataset', 'predictor',
        # 'optimizer', 'dataloader', 'pretrain' sections (schema inferred from
        # the keys read below — confirm against the config files).
        # pretrained: if True, download and load published weights.
        self.config = config
        self.model, self.vocab = build_model(config)
        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']
        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']
        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']
        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']
        # NOTE(review): self.logger is only set when a log path is configured,
        # but train() calls self.logger.log() unconditionally — training with
        # logging disabled would raise AttributeError. TODO confirm intent.
        if logger:
            self.logger = Logger(logger)
        if pretrained:
            weight_file = download_weights(**config['pretrain'], quiet=config['quiet'])
            self.load_weights(weight_file)
        self.iter = 0
        self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09)
        # OneCycleLR is stepped once per training batch in step().
        self.scheduler = OneCycleLR(self.optimizer, **config['optimizer'])
        # self.optimizer = ScheduledOptim(
        #     Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        #     #config['transformer']['d_model'],
        #     512,
        #     **config['optimizer'])
        self.criterion = LabelSmoothingLoss(len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1)
        transforms = ImgAugTransform()
        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)
        self.train_losses = []

    def train(self):
        # Iteration-based training loop (num_iters steps, cycling through the
        # loader via StopIteration), with periodic console/file logging and
        # periodic validation; best full-sequence accuracy triggers an export.
        total_loss = 0
        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1
            start = time.time()
            try:
                batch = next(data_iter)
            except StopIteration:
                # Exhausted the loader: restart it to keep iterating.
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start
            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start
            total_loss += loss
            self.train_losses.append((self.iter, loss))
            if self.iter % self.print_every == 0:
                # Loss shown is the mean over the last print_every iterations.
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)
            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                acc_full_seq, acc_per_char = self.precision(self.metrics)
                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char)
                print(info)
                self.logger.log(info)
                # Model selection criterion: full-sequence accuracy.
                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

    def validate(self):
        # Compute mean validation loss over the whole valid loader (no grads);
        # restores train mode before returning.
        self.model.eval()
        total_loss = []
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']
                outputs = self.model(img, tgt_input, tgt_padding_mask)
                # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
                outputs = outputs.flatten(0, 1)
                tgt_output = tgt_output.flatten()
                loss = self.criterion(outputs, tgt_output)
                total_loss.append(loss.item())
                del outputs
                del loss
        total_loss = np.mean(total_loss)
        self.model.train()
        return total_loss

    def predict(self, sample=None):
        # Decode the validation set (beam search or greedy translate); returns
        # (predicted sentences, ground-truth sentences, image file names).
        # sample: if not None, stop once more than `sample` predictions exist.
        pred_sents = []
        actual_sents = []
        img_files = []
        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)
            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['img'], self.model)
            else:
                translated_sentence = translate(batch['img'], self.model)
            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            img_files.extend(batch['filenames'])
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            if sample != None and len(pred_sents) > sample:
                break
        return pred_sents, actual_sents, img_files

    def precision(self, sample=None):
        # Accuracy over (at most `sample`) decoded validation sentences,
        # both at full-sequence and per-character granularity.
        pred_sents, actual_sents, _ = self.predict(sample=sample)
        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')
        return acc_full_seq, acc_per_char

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        # Plot up to `sample` validation images with predicted vs actual text;
        # errorcase=True restricts the plots to mispredicted examples.
        pred_sents, actual_sents, img_files = self.predict(sample)
        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
        img_files = img_files[:sample]
        fontdict = {'family': fontname, 'size': fontsize}
        for vis_idx in range(0, len(img_files)):
            img_path = img_files[vis_idx]
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            img = Image.open(open(img_path, 'rb'))
            plt.figure()
            plt.imshow(img)
            plt.title('pred: {} - actual: {}'.format(pred_sent, actual_sent),
                      loc='left',
                      fontdict=fontdict)
            plt.axis('off')
        plt.show()

    def visualize_dataset(self, sample=16, fontname='serif'):
        # Show `sample` training images with their decoded target sentences.
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())
                plt.figure()
                plt.title('sent: {}'.format(sent),
                          loc='center',
                          fontname=fontname)
                plt.imshow(img)
                plt.axis('off')
                n += 1
                if n >= sample:
                    plt.show()
                    return

    def load_checkpoint(self, filename):
        # Restore optimizer/model/iteration/loss-history from a checkpoint.
        checkpoint = torch.load(filename)
        # NOTE(review): `optim` is constructed here but never used — the state
        # is loaded into self.optimizer instead. Likely leftover from the old
        # ScheduledOptim setup; confirm and remove.
        optim = ScheduledOptim(
            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
            self.config['transformer']['d_model'], **self.config['optimizer'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        # Persist everything needed to resume training (creates parent dirs).
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses
        }
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(state, filename)

    def load_weights(self, filename):
        # Load model weights leniently: report missing parameters and drop
        # shape-mismatched tensors before a strict=False load.
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape'.format(name))
                del state_dict[name]
        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        # Save only the model weights (creates parent dirs).
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        # Move the tensor fields of a batch dict to self.device
        # (non-blocking); filenames pass through untouched.
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)
        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames']
        }
        return batch

    def data_gen(self, lmdb_path, data_root, annotation, transform=None):
        # Build a DataLoader over the LMDB-backed OCRDataset, batching images
        # of similar width together via ClusterRandomSampler (so shuffle=False).
        dataset = OCRDataset(
            lmdb_path=lmdb_path,
            root_dir=data_root,
            annotation_path=annotation,
            vocab=self.vocab,
            transform=transform,
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])
        sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=False,
                         drop_last=False,
                         **self.config['dataloader'])
        return gen

    def data_gen_v1(self, lmdb_path, data_root, annotation):
        # Legacy generator kept for compatibility; always builds on 'cpu'.
        data_gen = DataGen(
            data_root,
            annotation,
            self.vocab,
            'cpu',
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])
        return data_gen

    def step(self, batch):
        # One optimization step: forward, label-smoothed CE over flattened
        # logits/targets, backward, grad-norm clip at 1, optimizer + per-batch
        # OneCycleLR scheduler step. Returns the scalar loss.
        self.model.train()
        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']
        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  #flatten(0, 1)
        tgt_output = tgt_output.view(-1)  #flatten()
        loss = self.criterion(outputs, tgt_output)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        self.scheduler.step()
        loss_item = loss.item()
        return loss_item
constraints2.append(constraint.item()) else: loss = args.w3 * entropy losses2.append(loss.item()) loss.backward() opt2b.step() opt2b.zero_grad() if imgs.shape[1] == 224: imgs = imgs[:, AVIRIS_TO_SENTINEL2, :, :] elif imgs.shape[1] == 12: imgs = imgs[:, SENTINEL2_TO_4B, :, :] else: imgs = None if args.schedule: sched1.step() sched2.step() if epoch % 107 == 0: log.info(f'Saving checkpoint to /tmp/checkpoint.pth') torch.save(model.state_dict(), '/tmp/checkpoint.pth') mean_constraint1 = np.mean(constraints1) mean_constraint2 = np.mean(constraints2) mean_entropy1 = np.mean(entropies1) mean_entropy2 = np.mean(entropies2) mean_loss1 = np.mean(losses1) mean_loss2 = np.mean(losses2) log.info( f'epoch={epoch:<3d} loss={mean_loss1:+1.5f} entropy={mean_entropy1:+1.5f} constraint={mean_constraint1:+1.5f}' )
def train(
    self,
    base_path: Union[Path, str],
    learning_rate: float = 0.1,
    mini_batch_size: int = 32,
    mini_batch_chunk_size: int = None,
    max_epochs: int = 100,
    scheduler=AnnealOnPlateau,
    cycle_momentum: bool = False,
    anneal_factor: float = 0.5,
    patience: int = 3,
    initial_extra_patience=0,
    min_learning_rate: float = 0.0001,
    train_with_dev: bool = False,
    train_with_test: bool = False,
    monitor_train: bool = False,
    monitor_test: bool = False,
    embeddings_storage_mode: str = "cpu",
    checkpoint: bool = False,
    save_final_model: bool = True,
    anneal_with_restarts: bool = False,
    anneal_with_prestarts: bool = False,
    batch_growth_annealing: bool = False,
    shuffle: bool = True,
    param_selection_mode: bool = False,
    write_weights: bool = False,
    num_workers: int = 6,
    sampler=None,
    use_amp: bool = False,
    amp_opt_level: str = "O1",
    eval_on_train_fraction=0.0,
    eval_on_train_shuffle=False,
    save_model_at_each_epoch=False,
    **kwargs,
) -> dict:
    """
    Trains any class that implements the flair.nn.Model interface.

    :param base_path: Main path to which all output during training is logged and models are saved
    :param learning_rate: Initial learning rate (or max, if scheduler is OneCycleLR)
    :param mini_batch_size: Size of mini-batches during training
    :param mini_batch_chunk_size: If mini-batches are larger than this number, they get broken down into chunks of
        this size for processing purposes
    :param max_epochs: Maximum number of epochs to train. Terminates training if this number is surpassed.
    :param scheduler: The learning rate scheduler to use
    :param cycle_momentum: If scheduler is OneCycleLR, whether the scheduler should cycle also the momentum
    :param anneal_factor: The factor by which the learning rate is annealed
    :param patience: Patience is the number of epochs with no improvement the Trainer waits
        until annealing the learning rate
    :param min_learning_rate: If the learning rate falls below this threshold, training terminates
    :param train_with_dev: If True, training is performed using both train+dev data
    :param monitor_train: If True, training data is evaluated at end of each epoch
    :param monitor_test: If True, test data is evaluated at end of each epoch
    :param embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed),
        'cpu' (embeddings are stored on CPU) or 'gpu' (embeddings are stored on GPU)
    :param checkpoint: If True, a full checkpoint is saved at end of each epoch
    :param save_final_model: If True, final model is saved
    :param anneal_with_restarts: If True, the last best model is restored when annealing the learning rate
    :param shuffle: If True, data is shuffled during training
    :param param_selection_mode: If True, testing is performed against dev data. Use this mode when doing
        parameter selection.
    :param num_workers: Number of workers in your data loader.
    :param sampler: You can pass a data sampler here for special sampling of data.
    :param eval_on_train_fraction: the fraction of train data to do the evaluation on,
        if 0. the evaluation is not performed on fraction of training data,
        if 'dev' the size is determined from dev set size
    :param eval_on_train_shuffle: if True the train data fraction is determined on the start of training
        and kept fixed during training, otherwise it's sampled at beginning of each epoch
    :param save_model_at_each_epoch: If True, at each epoch the thus far trained model will be saved
    :param kwargs: Other arguments for the Optimizer
    :return: dict with 'test_score', 'dev_score_history', 'train_loss_history' and 'dev_loss_history'
    """

    # --- optional TensorBoard setup -----------------------------------------
    if self.use_tensorboard:
        try:
            from torch.utils.tensorboard import SummaryWriter

            writer = SummaryWriter()
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit
        except Exception:
            log_line(log)
            log.warning(
                "ATTENTION! PyTorch >= 1.1.0 and pillow are required for TensorBoard support!"
            )
            log_line(log)
            self.use_tensorboard = False
            pass

    # --- apex / AMP sanity checks -------------------------------------------
    if use_amp:
        if sys.version_info < (3, 0):
            raise RuntimeError(
                "Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training.")

    if mini_batch_chunk_size is None:
        mini_batch_chunk_size = mini_batch_size
    # keep min LR meaningfully below the starting LR
    if learning_rate < min_learning_rate:
        min_learning_rate = learning_rate / 10

    initial_learning_rate = learning_rate

    # cast string to Path
    if type(base_path) is str:
        base_path = Path(base_path)

    log_handler = add_file_handler(log, base_path / "training.log")

    log_line(log)
    log.info(f'Model: "{self.model}"')
    log_line(log)
    log.info(f'Corpus: "{self.corpus}"')
    log_line(log)
    log.info("Parameters:")
    log.info(f' - learning_rate: "{learning_rate}"')
    log.info(f' - mini_batch_size: "{mini_batch_size}"')
    log.info(f' - patience: "{patience}"')
    log.info(f' - anneal_factor: "{anneal_factor}"')
    log.info(f' - max_epochs: "{max_epochs}"')
    log.info(f' - shuffle: "{shuffle}"')
    log.info(f' - train_with_dev: "{train_with_dev}"')
    log.info(f' - batch_growth_annealing: "{batch_growth_annealing}"')
    log_line(log)
    log.info(f'Model training base path: "{base_path}"')
    log_line(log)
    log.info(f"Device: {flair.device}")
    log_line(log)
    log.info(f"Embeddings storage mode: {embeddings_storage_mode}")

    if isinstance(self.model, SequenceTagger
                  ) and self.model.weight_dict and self.model.use_crf:
        log_line(log)
        log.warning(
            f'WARNING: Specified class weights will not take effect when using CRF'
        )

    # determine what splits (train, dev, test) to evaluate and log
    log_train = True if monitor_train else False
    log_test = (True if (not param_selection_mode and self.corpus.test
                         and monitor_test) else False)
    log_dev = False if train_with_dev or not self.corpus.dev else True
    log_train_part = (True if (eval_on_train_fraction == "dev"
                               or eval_on_train_fraction > 0.0) else False)

    if log_train_part:
        train_part_size = (len(
            self.corpus.dev) if eval_on_train_fraction == "dev" else int(
                len(self.corpus.train) * eval_on_train_fraction))
        assert train_part_size > 0
        if not eval_on_train_shuffle:
            train_part_indices = list(range(train_part_size))
            train_part = torch.utils.data.dataset.Subset(
                self.corpus.train, train_part_indices)

    # prepare loss logging file and set up header
    loss_txt = init_output_file(base_path, "loss.tsv")

    weight_extractor = WeightExtractor(base_path)

    optimizer: torch.optim.Optimizer = self.optimizer(
        self.model.parameters(), lr=learning_rate, **kwargs)

    if use_amp:
        self.model, optimizer = amp.initialize(self.model,
                                               optimizer,
                                               opt_level=amp_opt_level)

    # minimize training loss if training with dev data, else maximize dev score
    anneal_mode = "min" if train_with_dev else "max"

    if scheduler == OneCycleLR:
        dataset_size = len(self.corpus.train)
        if train_with_dev:
            dataset_size += len(self.corpus.dev)
        lr_scheduler = OneCycleLR(
            optimizer,
            max_lr=learning_rate,
            steps_per_epoch=dataset_size // mini_batch_size + 1,
            # if we load a checkpoint, we have already trained for self.epoch
            epochs=max_epochs - self.epoch,
            pct_start=0.0,
            cycle_momentum=cycle_momentum)
    else:
        lr_scheduler = scheduler(
            optimizer,
            factor=anneal_factor,
            patience=patience,
            initial_extra_patience=initial_extra_patience,
            mode=anneal_mode,
            verbose=True,
        )

    if (isinstance(lr_scheduler, OneCycleLR) and batch_growth_annealing):
        raise ValueError(
            "Batch growth with OneCycle policy is not implemented.")

    train_data = self.corpus.train

    # if training also uses dev/train data, include in training set
    if train_with_dev or train_with_test:
        parts = [self.corpus.train]
        if train_with_dev:
            parts.append(self.corpus.dev)
        if train_with_test:
            parts.append(self.corpus.test)
        train_data = ConcatDataset(parts)

    # initialize sampler if provided
    if sampler is not None:
        # init with default values if only class is provided
        if inspect.isclass(sampler):
            sampler = sampler()
        # set dataset to sample from
        sampler.set_dataset(train_data)
        shuffle = False

    dev_score_history = []
    dev_loss_history = []
    train_loss_history = []

    micro_batch_size = mini_batch_chunk_size

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        previous_learning_rate = learning_rate
        momentum = 0
        for group in optimizer.param_groups:
            if "momentum" in group:
                momentum = group["momentum"]

        for self.epoch in range(self.epoch + 1, max_epochs + 1):
            log_line(log)

            if anneal_with_prestarts:
                last_epoch_model_state_dict = copy.deepcopy(
                    self.model.state_dict())

            if eval_on_train_shuffle:
                # FIX: was `list(range(self.corpus.train))` — range() on a
                # Dataset raises TypeError; must use the dataset's length.
                train_part_indices = list(range(len(self.corpus.train)))
                random.shuffle(train_part_indices)
                train_part_indices = train_part_indices[:train_part_size]
                train_part = torch.utils.data.dataset.Subset(
                    self.corpus.train, train_part_indices)

            # get new learning rate
            for group in optimizer.param_groups:
                learning_rate = group["lr"]

            if learning_rate != previous_learning_rate and batch_growth_annealing:
                mini_batch_size *= 2

            # reload last best model if annealing with restarts is enabled
            if ((anneal_with_restarts or anneal_with_prestarts)
                    and learning_rate != previous_learning_rate
                    and (base_path / "best-model.pt").exists()):
                if anneal_with_restarts:
                    log.info("resetting to best model")
                    self.model.load_state_dict(
                        self.model.load(base_path /
                                        "best-model.pt").state_dict())
                if anneal_with_prestarts:
                    log.info("resetting to pre-best model")
                    self.model.load_state_dict(
                        self.model.load(base_path /
                                        "pre-best-model.pt").state_dict())

            previous_learning_rate = learning_rate

            # stop training if learning rate becomes too small
            if (not isinstance(lr_scheduler, OneCycleLR)
                ) and learning_rate < min_learning_rate:
                log_line(log)
                log.info("learning rate too small - quitting training!")
                log_line(log)
                break

            batch_loader = DataLoader(
                train_data,
                batch_size=mini_batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                sampler=sampler,
            )

            self.model.train()

            train_loss: float = 0

            seen_batches = 0
            total_number_of_batches = len(batch_loader)

            modulo = max(1, int(total_number_of_batches / 10))

            # process mini-batches
            batch_time = 0
            for batch_no, batch in enumerate(batch_loader):

                start_time = time.time()

                # zero the gradients on the model and optimizer
                self.model.zero_grad()
                optimizer.zero_grad()

                # if necessary, make batch_steps
                batch_steps = [batch]
                if len(batch) > micro_batch_size:
                    batch_steps = [
                        batch[x:x + micro_batch_size]
                        for x in range(0, len(batch), micro_batch_size)
                    ]

                # forward and backward for batch
                for batch_step in batch_steps:

                    # forward pass
                    loss = self.model.forward_loss(batch_step)

                    # Backward
                    if use_amp:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                # do the optimizer step
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                optimizer.step()

                # do the scheduler step if one-cycle
                if isinstance(lr_scheduler, OneCycleLR):
                    lr_scheduler.step()
                    # get new learning rate
                    for group in optimizer.param_groups:
                        learning_rate = group["lr"]
                        if "momentum" in group:
                            momentum = group["momentum"]

                seen_batches += 1
                train_loss += loss.item()

                # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                store_embeddings(batch, embeddings_storage_mode)

                batch_time += time.time() - start_time
                if seen_batches % modulo == 0:
                    momentum_info = f' - momentum: {momentum:.4f}' if cycle_momentum else ''
                    log.info(
                        f"epoch {self.epoch} - iter {seen_batches}/{total_number_of_batches} - loss "
                        f"{train_loss / seen_batches:.8f} - samples/sec: {mini_batch_size * modulo / batch_time:.2f}"
                        f" - lr: {learning_rate:.6f}{momentum_info}")
                    batch_time = 0
                    iteration = self.epoch * total_number_of_batches + batch_no
                    if not param_selection_mode and write_weights:
                        weight_extractor.extract_weights(
                            self.model.state_dict(), iteration)

            train_loss /= seen_batches

            self.model.eval()

            log_line(log)
            log.info(
                f"EPOCH {self.epoch} done: loss {train_loss:.4f} - lr {learning_rate:.7f}"
            )

            if self.use_tensorboard:
                writer.add_scalar("train_loss", train_loss, self.epoch)

            # anneal against train loss if training with dev, otherwise anneal against dev score
            current_score = train_loss

            # evaluate on train / dev / test split depending on training settings
            result_line: str = ""

            if log_train:
                train_eval_result, train_loss = self.model.evaluate(
                    self.corpus.train,
                    mini_batch_size=mini_batch_chunk_size,
                    num_workers=num_workers,
                    embedding_storage_mode=embeddings_storage_mode,
                )
                result_line += f"\t{train_eval_result.log_line}"

                # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                store_embeddings(self.corpus.train, embeddings_storage_mode)

            if log_train_part:
                train_part_eval_result, train_part_loss = self.model.evaluate(
                    train_part,
                    mini_batch_size=mini_batch_chunk_size,
                    num_workers=num_workers,
                    embedding_storage_mode=embeddings_storage_mode,
                )
                result_line += (
                    f"\t{train_part_loss}\t{train_part_eval_result.log_line}"
                )
                log.info(
                    f"TRAIN_SPLIT : loss {train_part_loss} - score {round(train_part_eval_result.main_score, 4)}"
                )

            if log_dev:
                dev_eval_result, dev_loss = self.model.evaluate(
                    self.corpus.dev,
                    mini_batch_size=mini_batch_chunk_size,
                    num_workers=num_workers,
                    out_path=base_path / "dev.tsv",
                    embedding_storage_mode=embeddings_storage_mode,
                )
                result_line += f"\t{dev_loss}\t{dev_eval_result.log_line}"
                log.info(
                    f"DEV : loss {dev_loss} - score {round(dev_eval_result.main_score, 4)}"
                )
                # calculate scores using dev data if available
                # append dev score to score history
                dev_score_history.append(dev_eval_result.main_score)
                dev_loss_history.append(dev_loss.item())

                current_score = dev_eval_result.main_score

                # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                store_embeddings(self.corpus.dev, embeddings_storage_mode)

                if self.use_tensorboard:
                    writer.add_scalar("dev_loss", dev_loss, self.epoch)
                    writer.add_scalar("dev_score",
                                      dev_eval_result.main_score, self.epoch)

            if log_test:
                test_eval_result, test_loss = self.model.evaluate(
                    self.corpus.test,
                    mini_batch_size=mini_batch_chunk_size,
                    num_workers=num_workers,
                    out_path=base_path / "test.tsv",
                    embedding_storage_mode=embeddings_storage_mode,
                )
                result_line += f"\t{test_loss}\t{test_eval_result.log_line}"
                log.info(
                    f"TEST : loss {test_loss} - score {round(test_eval_result.main_score, 4)}"
                )

                # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                store_embeddings(self.corpus.test, embeddings_storage_mode)

                if self.use_tensorboard:
                    writer.add_scalar("test_loss", test_loss, self.epoch)
                    writer.add_scalar("test_score",
                                      test_eval_result.main_score,
                                      self.epoch)

            # determine learning rate annealing through scheduler. Use auxiliary metric for AnnealOnPlateau
            if log_dev and isinstance(lr_scheduler, AnnealOnPlateau):
                lr_scheduler.step(current_score, dev_loss)
            elif not isinstance(lr_scheduler, OneCycleLR):
                lr_scheduler.step(current_score)

            train_loss_history.append(train_loss)

            # determine bad epoch number
            try:
                bad_epochs = lr_scheduler.num_bad_epochs
            # FIX: was a bare `except:`; only OneCycleLR lacks the attribute
            except AttributeError:
                bad_epochs = 0
            for group in optimizer.param_groups:
                new_learning_rate = group["lr"]
            if new_learning_rate != previous_learning_rate:
                bad_epochs = patience + 1
                if previous_learning_rate == initial_learning_rate:
                    bad_epochs += initial_extra_patience

            # log bad epochs
            log.info(f"BAD EPOCHS (no improvement): {bad_epochs}")

            # output log file
            with open(loss_txt, "a") as f:

                # make headers on first epoch
                if self.epoch == 1:
                    f.write(
                        f"EPOCH\tTIMESTAMP\tBAD_EPOCHS\tLEARNING_RATE\tTRAIN_LOSS"
                    )

                    if log_train:
                        f.write("\tTRAIN_" + "\tTRAIN_".join(
                            train_eval_result.log_header.split("\t")))

                    if log_train_part:
                        f.write("\tTRAIN_PART_LOSS\tTRAIN_PART_" +
                                "\tTRAIN_PART_".join(
                                    train_part_eval_result.log_header.
                                    split("\t")))

                    if log_dev:
                        f.write("\tDEV_LOSS\tDEV_" + "\tDEV_".join(
                            dev_eval_result.log_header.split("\t")))

                    if log_test:
                        f.write("\tTEST_LOSS\tTEST_" + "\tTEST_".join(
                            test_eval_result.log_header.split("\t")))

                f.write(
                    f"\n{self.epoch}\t{datetime.datetime.now():%H:%M:%S}\t{bad_epochs}\t{learning_rate:.4f}\t{train_loss}"
                )
                f.write(result_line)

            # if checkpoint is enabled, save model at each epoch
            if checkpoint and not param_selection_mode:
                self.save_checkpoint(base_path / "checkpoint.pt")

            # if we use dev data, remember best model based on dev evaluation score
            if ((not train_with_dev or anneal_with_restarts
                 or anneal_with_prestarts) and not param_selection_mode
                    and not isinstance(lr_scheduler, OneCycleLR)
                    and current_score == lr_scheduler.best
                    and bad_epochs == 0):
                print("saving best model")
                self.model.save(base_path / "best-model.pt")

                if anneal_with_prestarts:
                    current_state_dict = self.model.state_dict()
                    self.model.load_state_dict(last_epoch_model_state_dict)
                    self.model.save(base_path / "pre-best-model.pt")
                    self.model.load_state_dict(current_state_dict)

            if save_model_at_each_epoch:
                print("saving model of current epoch")
                model_name = "model_epoch_" + str(self.epoch) + ".pt"
                self.model.save(base_path / model_name)

        # if we do not use dev data for model selection, save final model
        if save_final_model and not param_selection_mode:
            self.model.save(base_path / "final-model.pt")

    except KeyboardInterrupt:
        log_line(log)
        log.info("Exiting from training early.")

        if self.use_tensorboard:
            writer.close()

        if not param_selection_mode:
            log.info("Saving model ...")
            self.model.save(base_path / "final-model.pt")
            log.info("Done.")

    # test best model if test data is present
    if self.corpus.test and not train_with_test:
        final_score = self.final_test(base_path, mini_batch_chunk_size,
                                      num_workers)
    else:
        final_score = 0
        log.info("Test data not provided setting final score to 0")

    log.removeHandler(log_handler)

    if self.use_tensorboard:
        writer.close()

    return {
        "test_score": final_score,
        "dev_score_history": dev_score_history,
        "train_loss_history": train_loss_history,
        "dev_loss_history": dev_loss_history,
    }
def training_function(config, args):
    """Run one full fine-tuning job: data prep, training, and evaluation.

    Fine-tunes the classifier head of a pretrained "resnet50d" on a folder of
    .jpg images, using `Accelerator` for device placement / fp16 handling, and
    reports top-1 accuracy after every epoch via `accelerator.print`.

    Parameters
    ----------
    config : mapping of hyper-parameters with keys "lr", "num_epochs", "seed",
        "batch_size" and "image_size" (a single int or an (h, w) pair).
    args : namespace providing `fp16`, `cpu` and `data_dir`.

    Returns
    -------
    None — per-epoch accuracy is printed on the main process only.
    """
    # Initialize accelerator (owns device placement and mixed precision).
    accelerator = Accelerator(fp16=args.fp16, cpu=args.cpu)
    # Unpack hyper-parameters; int() casts guard against string values coming
    # from a hyper-parameter sweep config.
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    image_size = config["image_size"]
    if not isinstance(image_size, (list, tuple)):
        # A single int means a square image.
        image_size = (image_size, image_size)

    # Grab all the image filenames (only .jpg files are considered).
    file_names = [os.path.join(args.data_dir, fname) for fname in os.listdir(args.data_dir) if fname.endswith(".jpg")]

    # Build the label correspondences; sorting the unique labels makes the
    # label -> id assignment deterministic across runs and processes.
    all_labels = [extract_label(fname) for fname in file_names]
    id_to_label = list(set(all_labels))
    id_to_label.sort()
    label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}

    # Set the seed BEFORE splitting the data, so every process draws the same
    # permutation below (and model weight init further down is reproducible).
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Split our filenames between train (80%) and validation (20%).
    random_perm = np.random.permutation(len(file_names))
    cut = int(0.8 * len(file_names))
    train_split = random_perm[:cut]
    eval_split = random_perm[cut:]

    # For training we use a simple RandomResizedCrop as augmentation.
    train_tfm = Compose([RandomResizedCrop(image_size, scale=(0.5, 1.0)), ToTensor()])
    train_dataset = PetsDataset(
        [file_names[i] for i in train_split], image_transform=train_tfm, label_to_id=label_to_id
    )

    # For evaluation, we use a deterministic Resize.
    eval_tfm = Compose([Resize(image_size), ToTensor()])
    eval_dataset = PetsDataset([file_names[i] for i in eval_split], image_transform=eval_tfm, label_to_id=label_to_id)

    # Instantiate dataloaders.
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size, num_workers=4)

    # Instantiate the model here (after seeding) so the seed also controls the
    # initialization of the new classifier head.
    model = create_model("resnet50d", pretrained=True, num_classes=len(label_to_id))

    # We could avoid this line since the accelerator is set with `device_placement=True` (default value).
    # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the
    # optimizer creation, otherwise training will not work on TPU (`accelerate` will kindly throw an error
    # to make us aware of that).
    model = model.to(accelerator.device)

    # Freeze the backbone; only the classifier head is trained.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.get_classifier().parameters():
        param.requires_grad = True

    # Pre-broadcast the per-channel normalization constants to shape
    # (1, C, 1, 1) once, so each batch is normalized with a single tensor
    # expression instead of a per-image transform (a bit faster).
    mean = torch.tensor(model.default_cfg["mean"])[None, :, None, None].to(accelerator.device)
    std = torch.tensor(model.default_cfg["std"])[None, :, None, None].to(accelerator.device)

    # Instantiate optimizer. lr / 25 matches OneCycleLR's default
    # div_factor=25: the schedule starts near this initial lr and peaks at
    # max_lr=lr.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr / 25)

    # Prepare everything. There is no specific order to remember: we just need
    # to unpack the objects in the same order we gave them to `prepare`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # Instantiate the learning rate scheduler AFTER preparing the training
    # dataloader, as `prepare` may change its length (e.g. when sharding the
    # data across processes).
    lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=len(train_dataloader))

    # Now we train the model.
    for epoch in range(num_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, batch["label"])
            # `accelerator.backward` replaces `loss.backward()` (it handles
            # gradient scaling under fp16, among other things).
            accelerator.backward(loss)
            optimizer.step()
            # OneCycleLR must be stepped once per BATCH, not per epoch.
            lr_scheduler.step()
            optimizer.zero_grad()

        # Evaluation: accumulate top-1 accuracy over the whole eval set.
        model.eval()
        accurate = 0
        num_elems = 0
        for step, batch in enumerate(eval_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            with torch.no_grad():
                outputs = model(inputs)
            predictions = outputs.argmax(dim=-1)
            # Gather predictions and labels from all processes before
            # comparing, so the metric covers the full eval set.
            # NOTE(review): with uneven final batches, `gather` may duplicate
            # samples across processes and slightly skew the count — confirm
            # if exact accuracy matters.
            accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch["label"])
            num_elems += accurate_preds.shape[0]
            accurate += accurate_preds.long().sum()

        eval_metric = accurate.item() / num_elems
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")