def train(args, checkpoint, mid_checkpoint_location, final_checkpoint_location, best_checkpoint_location,
          actfun, curr_seed, outfile_path, filename, fieldnames, curr_sample_size, device, num_params,
          curr_k=2, curr_p=1, curr_g=1, perm_method='shuffle'):
    """
    Runs training session for a given randomized model.

    :param args: arguments for this job
    :param checkpoint: current checkpoint (or None to start fresh)
    :param mid_checkpoint_location: path where the per-epoch (resumable) checkpoint is written
    :param final_checkpoint_location: path where the end-of-epoch checkpoint is written
    :param best_checkpoint_location: path where the best-validation-accuracy checkpoint is written
    :param actfun: activation function currently being used
    :param curr_seed: seed being used by current job
    :param outfile_path: path to save outputs from training session
    :param filename: unused here; kept for interface compatibility with callers
    :param fieldnames: column names for output file
    :param curr_sample_size: training sample size for this iteration
    :param device: reference to CUDA device for GPU support
    :param num_params: number of parameters in the network
    :param curr_k: k value for this iteration
    :param curr_p: p value for this iteration
    :param curr_g: g value for this iteration
    :param perm_method: permutation strategy for our network
    :return:
    """
    resnet_ver = args.resnet_ver
    resnet_width = args.resnet_width
    num_epochs = args.num_epochs

    # 1D activation functions do not combine inputs, so force k = 1 for them.
    actfuns_1d = ['relu', 'abs', 'swish', 'leaky_relu', 'tanh']
    if actfun in actfuns_1d:
        curr_k = 1

    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

    if args.one_shot:
        # One-shot mode: build a throwaway model/dataset to run the LR finder,
        # then discard the hyperparameters and use the discovered LR.
        util.seed_all(curr_seed)
        model_temp, _ = load_model(args.model, args.dataset, actfun, curr_k, curr_p, curr_g,
                                   num_params=num_params, perm_method=perm_method, device=device,
                                   resnet_ver=resnet_ver, resnet_width=resnet_width,
                                   verbose=args.verbose)
        util.seed_all(curr_seed)
        dataset_temp = util.load_dataset(args,
                                         args.model,
                                         args.dataset,
                                         seed=curr_seed,
                                         validation=True,
                                         batch_size=args.batch_size,
                                         train_sample_size=curr_sample_size,
                                         kwargs=kwargs)
        curr_hparams = hparams.get_hparams(args.model, args.dataset, actfun, curr_seed,
                                           num_epochs, args.search, args.hp_idx, args.one_shot)
        optimizer = optim.Adam(model_temp.parameters(),
                               betas=(curr_hparams['beta1'], curr_hparams['beta2']),
                               eps=curr_hparams['eps'],
                               weight_decay=curr_hparams['wd']
                               )
        start_time = time.time()
        # Only log LR-finder output when performing a hyperparameter search.
        oneshot_fieldnames = fieldnames if args.search else None
        oneshot_outfile_path = outfile_path if args.search else None
        lr = util.run_lr_finder(
            args, model_temp, dataset_temp[0], optimizer, nn.CrossEntropyLoss(),
            val_loader=dataset_temp[3],
            show=False,
            device=device,
            fieldnames=oneshot_fieldnames,
            outfile_path=oneshot_outfile_path,
            hparams=curr_hparams
        )
        curr_hparams = {}  # discard searched hparams; the found LR drives the schedule
        print("Time to find LR: {}\n LR found: {:3e}".format(time.time() - start_time, lr))
    else:
        curr_hparams = hparams.get_hparams(args.model, args.dataset, actfun, curr_seed,
                                           num_epochs, args.search, args.hp_idx)
        lr = curr_hparams['max_lr']

    criterion = nn.CrossEntropyLoss()

    model, model_params = load_model(args.model, args.dataset, actfun, curr_k, curr_p, curr_g,
                                     num_params=num_params, perm_method=perm_method, device=device,
                                     resnet_ver=resnet_ver, resnet_width=resnet_width,
                                     verbose=args.verbose)
    # Seed both the weight init and the dataset split identically for reproducibility.
    util.seed_all(curr_seed)
    model.apply(util.weights_init)
    util.seed_all(curr_seed)
    dataset = util.load_dataset(args,
                                args.model,
                                args.dataset,
                                seed=curr_seed,
                                validation=args.validation,
                                batch_size=args.batch_size,
                                train_sample_size=curr_sample_size,
                                kwargs=kwargs)
    loaders = {
        'aug_train': dataset[0],
        'train': dataset[1],
        'aug_eval': dataset[2],
        'eval': dataset[3],
    }
    sample_size = dataset[4]
    batch_size = dataset[5]

    if args.one_shot:
        # One-shot: Adam defaults plus a one-cycle schedule peaking at the found LR.
        optimizer = optim.Adam(model_params)
        scheduler = OneCycleLR(optimizer,
                               max_lr=lr,
                               epochs=num_epochs,
                               steps_per_epoch=int(math.floor(sample_size / batch_size)),
                               cycle_momentum=False
                               )
    else:
        optimizer = optim.Adam(model_params,
                               betas=(curr_hparams['beta1'], curr_hparams['beta2']),
                               eps=curr_hparams['eps'],
                               weight_decay=curr_hparams['wd']
                               )
        scheduler = OneCycleLR(optimizer,
                               max_lr=curr_hparams['max_lr'],
                               epochs=num_epochs,
                               steps_per_epoch=int(math.floor(sample_size / batch_size)),
                               pct_start=curr_hparams['cycle_peak'],
                               cycle_momentum=False
                               )

    epoch = 1
    if checkpoint is not None:
        # Resume model/optimizer/scheduler state from a mid-training checkpoint.
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        epoch = checkpoint['epoch']
        model.to(device)
        print("*** LOADED CHECKPOINT ***"
              "\n{}"
              "\nSeed: {}"
              "\nEpoch: {}"
              "\nActfun: {}"
              "\nNum Params: {}"
              "\nSample Size: {}"
              "\np: {}"
              "\nk: {}"
              "\ng: {}"
              "\nperm_method: {}".format(mid_checkpoint_location, checkpoint['curr_seed'],
                                         checkpoint['epoch'], checkpoint['actfun'],
                                         checkpoint['num_params'], checkpoint['sample_size'],
                                         checkpoint['p'], checkpoint['k'], checkpoint['g'],
                                         checkpoint['perm_method']))

    util.print_exp_settings(curr_seed, args.dataset, outfile_path, args.model, actfun,
                            util.get_model_params(model), sample_size, batch_size,
                            model.k, model.p, model.g, perm_method, resnet_ver, resnet_width,
                            args.optim, args.validation, curr_hparams)

    best_val_acc = 0

    if args.mix_pre_apex:
        # NVIDIA Apex mixed precision (O2) — wraps model and optimizer in-place.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # ---- Start Training
    while epoch <= num_epochs:

        if args.check_path != '':
            # Save a resumable checkpoint at the start of every epoch.
            torch.save({'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'curr_seed': curr_seed,
                        'epoch': epoch,
                        'actfun': actfun,
                        'num_params': num_params,
                        'sample_size': sample_size,
                        'p': curr_p, 'k': curr_k, 'g': curr_g,
                        'perm_method': perm_method
                        }, mid_checkpoint_location)

        # Per-epoch seed so shuffling/augmentation is reproducible across resumes.
        util.seed_all((curr_seed * args.num_epochs) + epoch)
        start_time = time.time()
        if args.mix_pre:
            scaler = torch.cuda.amp.GradScaler()

        # ---- Training
        model.train()
        total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
        for batch_idx, (x, targetx) in enumerate(loaders['aug_train']):
            # print(batch_idx)
            x, targetx = x.to(device), targetx.to(device)
            optimizer.zero_grad()
            if args.mix_pre:
                # Native AMP: autocast forward, scaled backward/step.
                with torch.cuda.amp.autocast():
                    output = model(x)
                    train_loss = criterion(output, targetx)
                total_train_loss += train_loss
                n += 1
                scaler.scale(train_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            elif args.mix_pre_apex:
                # Apex AMP: loss scaling via amp.scale_loss.
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss
                n += 1
                with amp.scale_loss(train_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                # Full precision.
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss
                n += 1
                train_loss.backward()
                optimizer.step()
            if args.optim == 'onecycle' or args.optim == 'onecycle_sgd':
                # OneCycleLR steps per batch, not per epoch.
                scheduler.step()
            _, prediction = torch.max(output.data, 1)
            num_correct += torch.sum(prediction == targetx.data)
            num_total += len(prediction)
        epoch_aug_train_loss = total_train_loss / n
        epoch_aug_train_acc = num_correct * 1.0 / num_total

        # Record combinact mixing weights (raw and softmaxed) for logging.
        alpha_primes = []
        alphas = []
        if model.actfun == 'combinact':
            for i, layer_alpha_primes in enumerate(model.all_alpha_primes):
                curr_alpha_primes = torch.mean(layer_alpha_primes, dim=0)
                curr_alphas = F.softmax(curr_alpha_primes, dim=0).data.tolist()
                curr_alpha_primes = curr_alpha_primes.tolist()
                alpha_primes.append(curr_alpha_primes)
                alphas.append(curr_alphas)

        # ---- Evaluation (augmented and non-augmented validation sets)
        model.eval()
        with torch.no_grad():
            total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (y, targety) in enumerate(loaders['aug_eval']):
                y, targety = y.to(device), targety.to(device)
                output = model(y)
                val_loss = criterion(output, targety)
                total_val_loss += val_loss
                n += 1
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targety.data)
                num_total += len(prediction)
            epoch_aug_val_loss = total_val_loss / n
            epoch_aug_val_acc = num_correct * 1.0 / num_total

            total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (y, targety) in enumerate(loaders['eval']):
                y, targety = y.to(device), targety.to(device)
                output = model(y)
                val_loss = criterion(output, targety)
                total_val_loss += val_loss
                n += 1
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targety.data)
                num_total += len(prediction)
            epoch_val_loss = total_val_loss / n
            epoch_val_acc = num_correct * 1.0 / num_total

        lr_curr = 0
        for param_group in optimizer.param_groups:
            lr_curr = param_group['lr']
        print(
            "    Epoch {}: LR {:1.5f} ||| aug_train_acc {:1.4f} | val_acc {:1.4f}, aug {:1.4f} ||| "
            "aug_train_loss {:1.4f} | val_loss {:1.4f}, aug {:1.4f} ||| time = {:1.4f}"
            .format(epoch, lr_curr, epoch_aug_train_acc, epoch_val_acc, epoch_aug_val_acc,
                    epoch_aug_train_loss, epoch_val_loss, epoch_aug_val_loss,
                    (time.time() - start_time)), flush=True
        )

        if args.hp_idx is None:
            hp_idx = -1
        else:
            hp_idx = args.hp_idx

        # Train-set metrics are only measured (in eval mode) on the final epoch;
        # until then they are logged as 0.
        epoch_train_loss = 0
        epoch_train_acc = 0
        if epoch == num_epochs:
            with torch.no_grad():
                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loaders['aug_train']):
                    x, targetx = x.to(device), targetx.to(device)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                epoch_aug_train_loss = total_train_loss / n
                epoch_aug_train_acc = num_correct * 1.0 / num_total

                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loaders['train']):
                    x, targetx = x.to(device), targetx.to(device)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                # BUG FIX: previously divided the stale validation accumulator
                # (total_val_loss) here, corrupting epoch_train_loss and gen_gap.
                epoch_train_loss = total_train_loss / n
                epoch_train_acc = num_correct * 1.0 / num_total

        # Outputting data to CSV at end of epoch
        with open(outfile_path, mode='a') as out_file:
            writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
            writer.writerow({'dataset': args.dataset,
                             'seed': curr_seed,
                             'epoch': epoch,
                             'time': (time.time() - start_time),
                             'actfun': model.actfun,
                             'sample_size': sample_size,
                             'model': args.model,
                             'batch_size': batch_size,
                             'alpha_primes': alpha_primes,
                             'alphas': alphas,
                             'num_params': util.get_model_params(model),
                             'var_nparams': args.var_n_params,
                             'var_nsamples': args.var_n_samples,
                             'k': curr_k,
                             'p': curr_p,
                             'g': curr_g,
                             'perm_method': perm_method,
                             'gen_gap': float(epoch_val_loss - epoch_train_loss),
                             'aug_gen_gap': float(epoch_aug_val_loss - epoch_aug_train_loss),
                             'resnet_ver': resnet_ver,
                             'resnet_width': resnet_width,
                             'epoch_train_loss': float(epoch_train_loss),
                             'epoch_train_acc': float(epoch_train_acc),
                             'epoch_aug_train_loss': float(epoch_aug_train_loss),
                             'epoch_aug_train_acc': float(epoch_aug_train_acc),
                             'epoch_val_loss': float(epoch_val_loss),
                             'epoch_val_acc': float(epoch_val_acc),
                             'epoch_aug_val_loss': float(epoch_aug_val_loss),
                             'epoch_aug_val_acc': float(epoch_aug_val_acc),
                             'hp_idx': hp_idx,
                             'curr_lr': lr_curr,
                             'found_lr': lr,
                             'hparams': curr_hparams,
                             'epochs': num_epochs
                             })

        epoch += 1

        if args.optim == 'rmsprop':
            # Non-onecycle schedules step once per epoch.
            scheduler.step()

        if args.checkpoints:
            if epoch_val_acc > best_val_acc:
                best_val_acc = epoch_val_acc
                torch.save({'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'curr_seed': curr_seed,
                            'epoch': epoch,
                            'actfun': actfun,
                            'num_params': num_params,
                            'sample_size': sample_size,
                            'p': curr_p, 'k': curr_k, 'g': curr_g,
                            'perm_method': perm_method
                            }, best_checkpoint_location)
            torch.save({'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'curr_seed': curr_seed,
                        'epoch': epoch,
                        'actfun': actfun,
                        'num_params': num_params,
                        'sample_size': sample_size,
                        'p': curr_p, 'k': curr_k, 'g': curr_g,
                        'perm_method': perm_method
                        }, final_checkpoint_location)
class Maskv3Agent:
    """Training/evaluation agent for the Maskv3 (YOLO-with-masks) model.

    Builds the datasets, loaders, model, AMP scaler, optimizer, OneCycleLR
    scheduler, loss function, and TensorBoard writer from a nested ``config``
    dict, then drives epoch-based training with periodic mAP-gated
    checkpointing.
    """

    def __init__(self, config):
        self.config = config
        # Select training device: requested CUDA device when available
        # (with cudnn autotuner enabled), otherwise CPU.
        target_device = config['train']['device']
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.device = target_device
        else:
            self.device = "cpu"
        # Load dataset (train uses train-mode augmentation, valid uses test-mode)
        train_transform = get_yolo_transform(config['dataset']['size'], mode='train')
        valid_transform = get_yolo_transform(config['dataset']['size'], mode='test')
        train_dataset = YOLOMaskDataset(
            csv_file=config['dataset']['train']['csv'],
            img_dir=config['dataset']['train']['img_root'],
            mask_dir=config['dataset']['train']['mask_root'],
            anchors=config['dataset']['anchors'],
            scales=config['dataset']['scales'],
            n_classes=config['dataset']['n_classes'],
            transform=train_transform)
        valid_dataset = YOLOMaskDataset(
            csv_file=config['dataset']['valid']['csv'],
            img_dir=config['dataset']['valid']['img_root'],
            mask_dir=config['dataset']['valid']['mask_root'],
            anchors=config['dataset']['anchors'],
            scales=config['dataset']['scales'],
            n_classes=config['dataset']['n_classes'],
            transform=valid_transform)
        # DataLoader
        self.train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            collate_fn=maskv3_collate_fn,
            pin_memory=True,
            shuffle=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            collate_fn=maskv3_collate_fn,
            pin_memory=True,
            shuffle=False,
            drop_last=False)
        # Model
        model = Maskv3(
            # Detection Branch
            in_channels=config['model']['in_channels'],
            num_classes=config['model']['num_classes'],
            # Prototype Branch
            num_masks=config['model']['num_masks'],
            num_features=config['model']['num_features'],
        )
        self.model = model.to(self.device)
        # Pre-scale anchor boxes to each feature-map scale so the loss and
        # decoding steps can use them directly.
        torch_anchors = torch.tensor(config['dataset']['anchors'])  # (3, 3, 2)
        torch_scales = torch.tensor(config['dataset']['scales'])  # (3,)
        scaled_anchors = (  # (3, 3, 2)
            torch_anchors * (torch_scales.unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)))
        self.scaled_anchors = scaled_anchors.to(self.device)
        # Optimizer (with AMP gradient scaler for mixed-precision training)
        self.scaler = torch.cuda.amp.GradScaler()
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=config['optimizer']['lr'],
            weight_decay=config['optimizer']['weight_decay'],
        )
        # Scheduler: one-cycle policy stepped once per training batch
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config['optimizer']['lr'],
            epochs=config['train']['n_epochs'],
            steps_per_epoch=len(self.train_loader),
        )
        # Loss function
        self.loss_fn = YOLOMaskLoss(num_classes=config['model']['num_classes'],
                                    num_masks=config['model']['num_masks'])
        # Tensorboard
        self.logdir = config['train']['logdir']
        self.board = SummaryWriter(logdir=config['train']['logdir'])
        # Training State
        self.current_epoch = 0  # last completed epoch (0 = not started)
        self.current_map = 0    # best mAP@0.5 observed so far

    def resume(self):
        """Restore model/optimizer/scheduler and training state from best.pth."""
        checkpoint_path = osp.join(self.logdir, 'best.pth')
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.current_map = checkpoint['current_map']
        self.current_epoch = checkpoint['current_epoch']
        print("Restore checkpoint at '{}'".format(self.current_epoch))

    def train(self):
        """Run the main training loop from current_epoch+1 to n_epochs.

        Before epoch config['valid']['when'], a checkpoint is saved every
        epoch; from that epoch on, mAP@0.5 is evaluated every 5 epochs and a
        checkpoint is saved only on improvement.
        """
        for epoch in range(self.current_epoch + 1,
                           self.config['train']['n_epochs'] + 1):
            self.current_epoch = epoch
            self._train_one_epoch()
            self._validate()
            accs = self._check_accuracy()
            if self.current_epoch < self.config['valid']['when']:
                self._save_checkpoint()
            if (self.current_epoch >= self.config['valid']['when']
                    and self.current_epoch % 5 == 0):
                mAP50 = self._check_map()
                if mAP50 > self.current_map:
                    self.current_map = mAP50
                    self._save_checkpoint()

    def finalize(self):
        """Run a final mAP evaluation after training ends."""
        self._check_map()

    def _train_one_epoch(self):
        """Train for one epoch over train_loader with AMP and per-batch LR stepping."""
        n_epochs = self.config['train']['n_epochs']
        current_epoch = self.current_epoch
        current_lr = self.optimizer.param_groups[0]['lr']
        loop = tqdm(self.train_loader,
                    leave=True,
                    desc=(f"Train Epoch:{current_epoch}/{n_epochs}"
                          f", LR: {current_lr:.5f}"))
        obj_losses = []
        box_losses = []
        noobj_losses = []
        class_losses = []
        total_losses = []
        segment_losses = []
        self.model.train()
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            masks = [m.to(self.device) for m in masks]  # (nM_g, H, W)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            # Model prediction (mixed precision); loss computed per scale
            with torch.cuda.amp.autocast():
                outs, prototypes = self.model(imgs)
                s1_loss = self.loss_fn(
                    outs[0], target_s1, self.scaled_anchors[0],  # Detection Branch
                    prototypes, masks,  # Prototype Branch
                )
                s2_loss = self.loss_fn(
                    outs[1], target_s2, self.scaled_anchors[1],  # Detection Branch
                    prototypes, masks,  # Prototype Branch
                )
                s3_loss = self.loss_fn(
                    outs[2], target_s3, self.scaled_anchors[2],  # Detection Branch
                    prototypes, masks,  # Prototype Branch
                )
                # Aggregate loss over the three scales
                obj_loss = s1_loss['obj_loss'] + s2_loss['obj_loss'] + s3_loss[
                    'obj_loss']
                box_loss = s1_loss['box_loss'] + s2_loss['box_loss'] + s3_loss[
                    'box_loss']
                noobj_loss = s1_loss['noobj_loss'] + s2_loss[
                    'noobj_loss'] + s3_loss['noobj_loss']
                class_loss = s1_loss['class_loss'] + s2_loss[
                    'class_loss'] + s3_loss['class_loss']
                segment_loss = s1_loss['segment_loss'] + s2_loss[
                    'segment_loss'] + s3_loss['segment_loss']
                total_loss = s1_loss['total_loss'] + s2_loss[
                    'total_loss'] + s3_loss['total_loss']
            # Running (epoch-mean) loss trackers
            total_losses.append(total_loss.item())
            obj_losses.append(obj_loss.item())
            noobj_losses.append(noobj_loss.item())
            box_losses.append(box_loss.item())
            class_losses.append(class_loss.item())
            segment_losses.append(segment_loss.item())
            # Update parameters (scaled backward for AMP), then step the
            # per-batch OneCycleLR schedule.
            self.optimizer.zero_grad()
            self.scaler.scale(total_loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()
            # Update progress bar with running means
            mean_total_loss = sum(total_losses) / len(total_losses)
            mean_obj_loss = sum(obj_losses) / len(obj_losses)
            mean_noobj_loss = sum(noobj_losses) / len(noobj_losses)
            mean_box_loss = sum(box_losses) / len(box_losses)
            mean_class_loss = sum(class_losses) / len(class_losses)
            mean_segment_loss = sum(segment_losses) / len(segment_losses)
            loop.set_postfix(
                loss=mean_total_loss,
                cls=mean_class_loss,
                box=mean_box_loss,
                obj=mean_obj_loss,
                noobj=mean_noobj_loss,
                segment=mean_segment_loss,
            )
        # Logging (epoch)
        epoch_total_loss = sum(total_losses) / len(total_losses)
        epoch_obj_loss = sum(obj_losses) / len(obj_losses)
        epoch_noobj_loss = sum(noobj_losses) / len(noobj_losses)
        epoch_box_loss = sum(box_losses) / len(box_losses)
        epoch_class_loss = sum(class_losses) / len(class_losses)
        epoch_segment_loss = sum(segment_losses) / len(segment_losses)
        self.board.add_scalar('Epoch Train Loss',
                              epoch_total_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train BOX Loss',
                              epoch_box_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train OBJ Loss',
                              epoch_obj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train NOOBJ Loss',
                              epoch_noobj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train CLASS Loss',
                              epoch_class_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train SEGMENT Loss',
                              epoch_segment_loss,
                              global_step=self.current_epoch)

    def _validate(self):
        """Compute and log validation losses (no parameter updates)."""
        n_epochs = self.config['train']['n_epochs']
        current_epoch = self.current_epoch
        current_lr = self.optimizer.param_groups[0]['lr']
        loop = tqdm(self.valid_loader,
                    leave=True,
                    desc=(f"Valid Epoch:{current_epoch}/{n_epochs}"
                          f", LR: {current_lr:.5f}"))
        obj_losses = []
        box_losses = []
        noobj_losses = []
        class_losses = []
        total_losses = []
        segment_losses = []
        self.model.eval()
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            masks = [m.to(self.device) for m in masks]  # (nM_g, H, W)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            # Model Prediction — gradient-free, mixed-precision forward pass
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outs, prototypes = self.model(imgs)
                    s1_loss = self.loss_fn(
                        outs[0], target_s1, self.scaled_anchors[0],  # Detection Branch
                        prototypes, masks,  # Prototype Branch
                    )
                    s2_loss = self.loss_fn(
                        outs[1], target_s2, self.scaled_anchors[1],  # Detection Branch
                        prototypes, masks,  # Prototype Branch
                    )
                    s3_loss = self.loss_fn(
                        outs[2], target_s3, self.scaled_anchors[2],  # Detection Branch
                        prototypes, masks,  # Prototype Branch
                    )
                    # Aggregate loss over the three scales
                    obj_loss = s1_loss['obj_loss'] + s2_loss['obj_loss'] + s3_loss[
                        'obj_loss']
                    box_loss = s1_loss['box_loss'] + s2_loss['box_loss'] + s3_loss[
                        'box_loss']
                    noobj_loss = s1_loss['noobj_loss'] + s2_loss[
                        'noobj_loss'] + s3_loss['noobj_loss']
                    class_loss = s1_loss['class_loss'] + s2_loss[
                        'class_loss'] + s3_loss['class_loss']
                    segment_loss = s1_loss['segment_loss'] + s2_loss[
                        'segment_loss'] + s3_loss['segment_loss']
                    total_loss = s1_loss['total_loss'] + s2_loss[
                        'total_loss'] + s3_loss['total_loss']
            # Running (epoch-mean) loss trackers
            obj_losses.append(obj_loss.item())
            box_losses.append(box_loss.item())
            noobj_losses.append(noobj_loss.item())
            class_losses.append(class_loss.item())
            total_losses.append(total_loss.item())
            segment_losses.append(segment_loss.item())
            # Update progress bar with running means
            mean_total_loss = sum(total_losses) / len(total_losses)
            mean_obj_loss = sum(obj_losses) / len(obj_losses)
            mean_noobj_loss = sum(noobj_losses) / len(noobj_losses)
            mean_box_loss = sum(box_losses) / len(box_losses)
            mean_class_loss = sum(class_losses) / len(class_losses)
            mean_segment_loss = sum(segment_losses) / len(segment_losses)
            loop.set_postfix(
                loss=mean_total_loss,
                cls=mean_class_loss,
                box=mean_box_loss,
                obj=mean_obj_loss,
                noobj=mean_noobj_loss,
                segment=mean_segment_loss,
            )
        # Logging (epoch)
        epoch_total_loss = sum(total_losses) / len(total_losses)
        epoch_obj_loss = sum(obj_losses) / len(obj_losses)
        epoch_noobj_loss = sum(noobj_losses) / len(noobj_losses)
        epoch_box_loss = sum(box_losses) / len(box_losses)
        epoch_class_loss = sum(class_losses) / len(class_losses)
        epoch_segment_loss = sum(segment_losses) / len(segment_losses)
        self.board.add_scalar('Epoch Valid Loss',
                              epoch_total_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid BOX Loss',
                              epoch_box_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid OBJ Loss',
                              epoch_obj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid NOOBJ Loss',
                              epoch_noobj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid CLASS Loss',
                              epoch_class_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid SEGMENT Loss',
                              epoch_segment_loss,
                              global_step=self.current_epoch)

    def _check_accuracy(self):
        """Measure objectness / no-objectness / class accuracy over the valid set.

        Cells with target objectness 1 are "object" cells, 0 are "no-object"
        cells; a cell is predicted as an object when sigmoid(pred[..., 4])
        exceeds config['valid']['conf_threshold'].

        :return: dict with percentage accuracies under keys 'cls', 'obj', 'noobj'
        """
        tot_obj = 0
        tot_noobj = 0
        correct_obj = 0
        correct_noobj = 0
        correct_class = 0
        self.model.eval()
        loop = tqdm(self.valid_loader, leave=True, desc=f"Check ACC")
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            batch_size = imgs.size(0)
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            targets = [target_s1, target_s2, target_s3]
            # Model Prediction
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outs, prototypes = self.model(imgs)
            for scale_idx in range(len(outs)):
                # Get output
                pred = outs[scale_idx]
                target = targets[scale_idx]
                # Get mask
                obj_mask = target[..., 4] == 1
                noobj_mask = target[..., 4] == 0
                # Count objects
                tot_obj += torch.sum(obj_mask)
                tot_noobj += torch.sum(noobj_mask)
                # No ground-truth objects at this scale: only no-object cells
                # can be scored; skip the object/class tallies.
                if torch.sum(obj_mask) == 0:
                    obj_pred = torch.sigmoid(
                        pred[..., 4]) > self.config['valid']['conf_threshold']
                    correct_noobj += torch.sum(
                        obj_pred[noobj_mask] == target[..., 4][noobj_mask])
                    continue
                # Count number of correctly classified objects
                correct_class += torch.sum((torch.argmax(
                    pred[..., 5:5 + self.config['model']['num_classes']][obj_mask],
                    dim=-1) == target[..., 5][obj_mask]))
                # Count number of correct objectness & non-objectness
                obj_pred = torch.sigmoid(
                    pred[..., 4]) > self.config['valid']['conf_threshold']
                correct_obj += torch.sum(
                    obj_pred[obj_mask] == target[..., 4][obj_mask])
                correct_noobj += torch.sum(
                    obj_pred[noobj_mask] == target[..., 4][noobj_mask])
        # Aggregate result (small epsilon guards against division by zero)
        acc_obj = (correct_obj / (tot_obj + 1e-6)) * 100
        acc_cls = (correct_class / (tot_obj + 1e-6)) * 100
        acc_noobj = (correct_noobj / (tot_noobj + 1e-6)) * 100
        accs = {
            'cls': acc_cls.item(),
            'obj': acc_obj.item(),
            'noobj': acc_noobj.item()
        }
        print(f"Epoch {self.current_epoch} [Accs]: {accs}")
        return accs

    def _check_map(self):
        """Decode predictions, apply per-class NMS, and compute mean average precision.

        :return: mAP at IoU threshold 0.5 over the validation set
        """
        sample_idx = 0
        all_pred_bboxes = []
        all_true_bboxes = []
        self.model.eval()
        loop = tqdm(self.valid_loader, leave=True, desc="Check mAP")
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            batch_size = imgs.size(0)
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            targets = [target_s1, target_s2, target_s3]
            # Model Forward
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    preds, prototypes = self.model(imgs)
            # Convert cells to bboxes
            # =================================================================
            true_bboxes = [[] for _ in range(batch_size)]
            pred_bboxes = [[] for _ in range(batch_size)]
            for scale_idx, (pred, target) in enumerate(zip(preds, targets)):
                scale = pred.size(2)
                anchors = self.scaled_anchors[scale_idx]  # (3, 2)
                anchors = anchors.reshape(1, 3, 1, 1, 2)  # (1, 3, 1, 1, 2)
                # Decode raw predictions in place: sigmoid xy offsets,
                # exp-scaled wh against anchors, sigmoid objectness.
                pred[..., 0:2] = torch.sigmoid(pred[..., 0:2])  # (N, 3, S, S, 2)
                pred[..., 2:4] = torch.exp(
                    pred[..., 2:4]) * anchors  # (N, 3, S, S, 2)
                pred[..., 4:5] = torch.sigmoid(pred[..., 4:5])  # (N, 3, S, S, 1)
                # Replace class logits with the argmax class index
                pred_cls_probs = F.softmax(
                    pred[..., 5:5 + self.config['model']['num_classes']],
                    dim=-1)  # (N, 3, S, S, C)
                _, indices = torch.max(pred_cls_probs, dim=-1)  # (N, 3, S, S)
                indices = indices.unsqueeze(-1)  # (N, 3, S, S, 1)
                pred = torch.cat([pred[..., :5], indices],
                                 dim=-1)  # (N, 3, S, S, 6)
                # Convert coordinate system to normalized format (xywh)
                pboxes = cells_to_boxes(cells=pred, scale=scale)  # (N, 3, S, S, 6)
                tboxes = cells_to_boxes(cells=target, scale=scale)  # (N, 3, S, S, 6)
                # Keep predicted boxes above the confidence threshold
                for idx, cell_boxes in enumerate(pboxes):
                    obj_mask = cell_boxes[
                        ..., 4] > self.config['valid']['conf_threshold']
                    boxes = cell_boxes[obj_mask]
                    pred_bboxes[idx] += boxes.tolist()
                # Keep ground-truth boxes (objectness ~ 1.0)
                for idx, cell_boxes in enumerate(tboxes):
                    obj_mask = cell_boxes[..., 4] > 0.99
                    boxes = cell_boxes[obj_mask]
                    true_bboxes[idx] += boxes.tolist()
            # Perform NMS batch-by-batch
            # =================================================================
            for batch_idx in range(batch_size):
                pbboxes = torch.tensor(pred_bboxes[batch_idx])
                tbboxes = torch.tensor(true_bboxes[batch_idx])
                # Perform NMS class-by-class
                for c in range(self.config['model']['num_classes']):
                    # Filter pred boxes of specific class
                    nms_pred_boxes = nms_by_class(
                        target=c,
                        bboxes=pbboxes,
                        iou_threshold=self.config['valid']
                        ['nms_iou_threshold'])
                    nms_true_boxes = nms_by_class(
                        target=c,
                        bboxes=tbboxes,
                        iou_threshold=self.config['valid']
                        ['nms_iou_threshold'])
                    all_pred_bboxes.extend([[sample_idx] + box
                                            for box in nms_pred_boxes])
                    all_true_bboxes.extend([[sample_idx] + box
                                            for box in nms_true_boxes])
                sample_idx += 1
        # Compute mAP at IoU 0.5 and 0.75
        # =================================================================
        # The format of the bboxes is (idx, x1, y1, x2, y2, conf, class)
        all_pred_bboxes = torch.tensor(all_pred_bboxes)  # (J, 7)
        all_true_bboxes = torch.tensor(all_true_bboxes)  # (K, 7)
        eval50 = mean_average_precision(
            all_pred_bboxes,
            all_true_bboxes,
            iou_threshold=0.5,
            n_classes=self.config['dataset']['n_classes'])
        eval75 = mean_average_precision(
            all_pred_bboxes,
            all_true_bboxes,
            iou_threshold=0.75,
            n_classes=self.config['dataset']['n_classes'])
        # NOTE(review): the "[email protected]" fragments below look like an
        # email-protection extraction artifact of the labels "mAP@0.5" /
        # "mAP@0.75" — confirm against the original repository before editing.
        print((
            f"Epoch {self.current_epoch}:\n"
            f"\t-[[email protected]]={eval50['mAP']:.3f}, [Recall]={eval50['recall']:.3f}, [Precision]={eval50['precision']:.3f}\n"
            f"\t-[[email protected]]={eval75['mAP']:.3f}, [Recall]={eval75['recall']:.3f}, [Precision]={eval75['precision']:.3f}\n"
        ))
        return eval50['mAP']

    def _save_checkpoint(self):
        """Save model/optimizer/scheduler and training state to <logdir>/best.pth."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'current_map': self.current_map,
            'current_epoch': self.current_epoch
        }
        checkpoint_path = osp.join(self.logdir, 'best.pth')
        torch.save(checkpoint, checkpoint_path)
        print("Save checkpoint at '{}'".format(checkpoint_path))
class Trainer():
    """Training harness for an OCR model (seq2seq/transformer or CRNN).

    Builds the model and vocab from ``config``, prepares LMDB-backed train
    and validation loaders, the optimizer/scheduler/criterion triple, and
    handles checkpointing plus TensorBoard logging.  ``train()`` is the main
    entry point.
    """

    def __init__(self, config, pretrained=True, augmentor=None):
        """Set up model, data, optimizer and logging from ``config``.

        :param config: nested dict-like configuration (see config.yml)
        :param pretrained: kept for interface compatibility; weight loading
            is actually driven by ``config['trainer']['pretrained']``
        :param augmentor: image augmentation transform; created lazily when
            None.  (Fix: the old ``augmentor=ImgAugTransform()`` default was
            evaluated once at import time and shared by every Trainer.)
        """
        if augmentor is None:
            augmentor = ImgAugTransform()

        self.config = config
        self.model, self.vocab = build_model(config)
        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        self.tensorboard_dir = config['monitor']['log_dir']
        if not os.path.exists(self.tensorboard_dir):
            os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']
        if self.is_finetuning:
            # Finetuning: fixed small LR, no LR schedule (scheduler stays None).
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)
        else:
            # From-scratch training: AdamW with a one-cycle LR schedule over
            # the full iteration budget.
            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        if self.model.seq_modeling == 'crnn':
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))
        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            # Move restored optimizer state tensors onto the training device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))
            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]
        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        self.logger.info("Number batch samples of valid: %d" %
                         len(self.valid_gen))

        # Persist the effective config next to the TensorBoard logs once.
        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)

    def train(self):
        """Main optimization loop over ``self.num_iters`` iterations.

        Logs train loss every ``print_every`` iterations and validates,
        checkpoints and TensorBoard-logs every ``valid_every`` iterations.
        """
        total_loss = 0
        total_loader_time = 0
        total_gpu_time = 0
        # FIX: initialised up front so the validation branch can log it even
        # when valid_every fires before the first print_every window
        # (previously an unbound ``lastest_loss`` -> NameError).
        latest_loss = 0.0

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: restart it for the next epoch.
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start

            start = time.time()
            # LOSS
            loss = self.step(batch)
            total_loss += loss
            self.train_losses.append((self.iter, loss))
            total_gpu_time += time.time() - start

            if self.iter % self.print_every == 0:
                info = 'Iter: {:06d} - Train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                latest_loss = total_loss / self.print_every
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                self.logger.info(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_time = time.time()
                val_loss = self.validate()
                acc_full_seq, acc_per_char, wer = self.precision(self.metrics)
                self.logger.info("Iter: {:06d}, start validating".format(
                    self.iter))
                info = 'Iter: {:06d} - Valid loss: {:.3f} - Acc full seq: {:.4f} - Acc per char: {:.4f} - WER: {:.4f} - Time: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, wer,
                    time.time() - val_time)
                self.logger.info(info)

                if acc_full_seq > self.best_acc:
                    self.save_weights(self.tensorboard_dir + "/best.pt")
                    self.best_acc = acc_full_seq
                    self.logger.info("Iter: {:06d} - Best acc: {:.4f}".format(
                        self.iter, self.best_acc))

                filename = 'last.pt'
                filepath = os.path.join(self.tensorboard_dir, filename)
                self.logger.info("Save checkpoint %s" % filename)
                self.save_checkpoint(filepath)

                log_loss = {'train loss': latest_loss, 'val loss': val_loss}
                self.writer.add_scalars('Loss', log_loss, self.iter)
                self.writer.add_scalar('WER', wer, self.iter)

    def validate(self):
        """Compute the mean loss over the validation loader."""
        self.model.eval()
        total_loss = []
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)

                if self.model.seq_modeling == 'crnn':
                    # NOTE(review): preds_size assumes a full batch; the last
                    # batch may be smaller unless the loader drops it — confirm
                    # drop_last is set for CRNN (see data_gen).
                    length = batch['labels_len']
                    preds_size = torch.autograd.Variable(
                        torch.IntTensor([outputs.size(0)] * self.batch_size))
                    loss = self.criterion(outputs, tgt_output, preds_size,
                                          length)
                else:
                    outputs = outputs.flatten(0, 1)
                    tgt_output = tgt_output.flatten()
                    loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())
                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()
        return total_loss

    def predict(self, sample=None):
        """Decode the validation set.

        :param sample: stop after collecting more than this many predictions
            (None decodes everything)
        :return: (pred_sents, actual_sents, img_files, probs_sents, imgs_sents)
        """
        pred_sents = []
        actual_sents = []
        img_files = []
        probs_sents = []
        imgs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            imgs_sents.extend(batch['img'])
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)

            # Visualize in tensorboard (first batch only).
            if idx == 0:
                try:
                    num_samples = self.config['monitor']['num_samples']
                    fig = plt.figure(figsize=(12, 15))
                    imgs_samples = imgs_sents[:num_samples]
                    preds_samples = pred_sents[:num_samples]
                    actuals_samples = actual_sents[:num_samples]
                    probs_samples = probs_sents[:num_samples]
                    for id_img in range(len(imgs_samples)):
                        img = imgs_samples[id_img]
                        img = img.permute(1, 2, 0)
                        img = img.cpu().detach().numpy()
                        ax = fig.add_subplot(num_samples, 1, id_img + 1,
                                             xticks=[], yticks=[])
                        plt.imshow(img)
                        ax.set_title(
                            "LB: {} \n Pred: {:.4f}-{}".format(
                                actuals_samples[id_img],
                                probs_samples[id_img],
                                preds_samples[id_img]),
                            color=('green' if actuals_samples[id_img] ==
                                   preds_samples[id_img] else 'red'),
                            fontdict={
                                'fontsize': 18,
                                'fontweight': 'medium'
                            })
                    self.writer.add_figure('predictions vs. actuals',
                                           fig,
                                           global_step=self.iter)
                except Exception as error:
                    # Best-effort visualization: a plotting failure must not
                    # abort prediction; `continue` also skips the sample-limit
                    # check for this batch (kept as-is).
                    print(error)
                    continue

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files, probs_sents, imgs_sents

    def precision(self, sample=None, measure_time=True):
        """Return (full-sequence acc, per-char acc, WER) on the valid set."""
        t1 = time.time()
        pred_sents, actual_sents, _, _, _ = self.predict(sample=sample)
        time_predict = time.time() - t1

        sensitive_case = self.config['predictor']['sensitive_case']
        acc_full_seq = compute_accuracy(actual_sents, pred_sents,
                                        sensitive_case, mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents, pred_sents,
                                        sensitive_case, mode='per_char')
        wer = compute_accuracy(actual_sents, pred_sents, sensitive_case,
                               mode='wer')
        if measure_time:
            print("Time: {:.4f}".format(time_predict / len(actual_sents)))
        return acc_full_seq, acc_per_char, wer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16,
                             save_fig=False):
        """Plot up to ``sample`` predictions (optionally only error cases)."""
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
            imgs = [imgs[i] for i in wrongs]

        img_files = img_files[:sample]
        fontdict = {'family': fontname, 'size': fontsize}

        ncols = 5
        nrows = int(math.ceil(len(img_files) / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 15))
        for vis_idx in range(0, len(img_files)):
            row = vis_idx // ncols
            col = vis_idx % ncols
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            prob = probs[vis_idx]
            img = imgs[vis_idx].permute(1, 2, 0).cpu().detach().numpy()
            ax[row, col].imshow(img)
            ax[row, col].set_title(
                "Pred: {: <2} \n Actual: {} \n prob: {:.2f}".format(
                    pred_sent, actual_sent, prob),
                fontname=fontname,
                color='r' if pred_sent != actual_sent else 'g')
            ax[row, col].get_xaxis().set_ticks([])
            ax[row, col].get_yaxis().set_ticks([])
        plt.subplots_adjust()
        if save_fig:
            fig.savefig('vis_prediction.png')
        plt.show()

    def log_prediction(self, sample=16, csv_file='model.csv'):
        """Dump up to ``sample`` predictions to ``csv_file``."""
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)
        save_predictions(csv_file, pred_sents, actual_sents, img_files)

    def vis_data(self, sample=20):
        """Plot ``sample`` training images with their labels to a PNG file."""
        ncols = 5
        nrows = int(math.ceil(sample / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 12))
        num_plots = 0
        for idx, batch in enumerate(self.train_gen):
            for vis_idx in range(self.batch_size):
                row = num_plots // ncols
                col = num_plots % ncols
                img = batch['img'][vis_idx].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(
                    batch['tgt_input'].T[vis_idx].tolist())
                ax[row, col].imshow(img)
                ax[row, col].set_title("Label: {: <2}".format(sent),
                                       fontsize=16,
                                       color='g')
                ax[row, col].get_xaxis().set_ticks([])
                ax[row, col].get_yaxis().set_ticks([])
                num_plots += 1
                if num_plots >= sample:
                    plt.subplots_adjust()
                    fig.savefig('vis_dataset.png')
                    return

    def load_checkpoint(self, filename):
        """Restore model/optimizer/scheduler and counters from a checkpoint."""
        checkpoint = torch.load(filename)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']
        # No scheduler exists when finetuning; skip restoring it then.
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.best_acc = checkpoint['best_acc']

    def save_checkpoint(self, filename):
        """Save full training state (model, optimizer, scheduler, counters)."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler':
            None if self.scheduler is None else self.scheduler.state_dict(),
            'best_acc': self.best_acc
        }
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights from a raw state_dict or a full checkpoint.

        Shape-mismatched entries are dropped (logged) and loading proceeds
        with ``strict=False``.
        """
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        if self.is_checkpoint(state_dict):
            self.model.load_state_dict(state_dict['state_dict'])
        else:
            for name, param in self.model.named_parameters():
                if name not in state_dict:
                    print('{} not found'.format(name))
                elif state_dict[name].shape != param.shape:
                    print('{} missmatching shape, required {} but found {}'.
                          format(name, param.shape, state_dict[name].shape))
                    del state_dict[name]
            self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state_dict to ``filename``."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), filename)

    def is_checkpoint(self, checkpoint):
        """Return True if ``checkpoint`` is a full checkpoint dict."""
        try:
            checkpoint['state_dict']
        except Exception:  # FIX: was a bare except
            return False
        else:
            return True

    def batch_to_device(self, batch):
        """Move batch tensors to the training device (non-blocking)."""
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)
        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames'],
            'labels_len': batch['labels_len']
        }
        return batch

    def data_gen(self,
                 lmdb_paths,
                 data_root,
                 annotation,
                 masked_language_model=True,
                 transform=None,
                 is_train=False):
        """Build a DataLoader over one or more LMDB datasets."""
        datasets = []
        for lmdb_path in lmdb_paths:
            dataset = OCRDataset(
                lmdb_path=lmdb_path,
                root_dir=data_root,
                annotation_path=annotation,
                vocab=self.vocab,
                transform=transform,
                image_height=self.config['dataset']['image_height'],
                image_min_width=self.config['dataset']['image_min_width'],
                image_max_width=self.config['dataset']['image_max_width'],
                separate=self.config['dataset']['separate'],
                batch_size=self.batch_size,
                is_padding=self.is_padding)
            datasets.append(dataset)
        # With several train LMDBs, concatenate; otherwise the loop leaves
        # the single dataset bound.
        if len(self.train_lmdb) > 1:
            dataset = torch.utils.data.ConcatDataset(datasets)

        if self.is_padding:
            sampler = None
        else:
            # NOTE(review): a custom sampler together with shuffle=True will
            # raise in DataLoader — presumably is_padding is False only when
            # is_train is False here; confirm against the configs.
            sampler = ClusterRandomSampler(dataset, self.batch_size, True)

        collate_fn = Collator(masked_language_model)
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=is_train,
                         drop_last=self.model.seq_modeling == 'crnn',
                         **self.config['dataloader'])
        return gen

    def step(self, batch):
        """One optimization step; returns the scalar loss value."""
        self.model.train()
        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img, tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)

        if self.model.seq_modeling == 'crnn':
            length = batch['labels_len']
            preds_size = torch.autograd.Variable(
                torch.IntTensor([outputs.size(0)] * self.batch_size))
            loss = self.criterion(outputs, tgt_output, preds_size, length)
        else:
            outputs = outputs.view(-1, outputs.size(2))  # B*S x N_class
            tgt_output = tgt_output.view(-1)  # B*S
            loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        # Finetuning runs without a schedule (see __init__).
        if not self.is_finetuning:
            self.scheduler.step()

        loss_item = loss.item()
        return loss_item

    def count_parameters(self, model):
        """Number of trainable parameters in ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def gen_pseudo_labels(self, outfile=None):
        """Decode the valid set and write ``file||||pred||||prob`` lines."""
        pred_sents = []
        img_files = []
        probs_sents = []
        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            pred_sents.extend(pred_sent)
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)

        assert len(pred_sents) == len(img_files) and len(img_files) == len(
            probs_sents)
        with open(outfile, 'w', encoding='utf-8') as f:
            for anno in zip(img_files, pred_sents, probs_sents):
                f.write('||||'.join([anno[0], anno[1],
                                     str(float(anno[2]))]) + '\n')
class Learner:
    """Single-fold training loop for a three-input tabular model.

    Trains with label-smoothed BCE, evaluates with plain
    ``BCEWithLogitsLoss``, tracks the best epoch by validation loss and
    checkpoints it under ``config.log_dir``.
    """

    def __init__(self, model, train_loader, valid_loader, fold, config, seed):
        """Bind data/model and build optimizer, scheduler and loggers."""
        self.config = config
        self.seed = seed
        self.device = self.config.device
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model.to(self.device)
        self.fold = fold

        self.logger = init_logger(
            config.log_dir, f'train_seed{self.seed}_fold{self.fold}.log')
        self.tb_logger = init_tb_logger(
            config.log_dir, f'train_seed{self.seed}_fold{self.fold}')
        if self.fold == 0:
            # Dump the full config once, on the first fold only.
            self.log('\n'.join(
                [f"{k} = {v}" for k, v in self.config.__dict__.items()]))

        self.criterion = SmoothBCEwLogits(smoothing=self.config.smoothing)
        self.evaluator = nn.BCEWithLogitsLoss()
        self.summary_loss = AverageMeter()
        self.history = {'train': [], 'valid': []}

        self.optimizer = Adam(self.model.parameters(),
                              lr=config.lr,
                              weight_decay=self.config.weight_decay)
        self.scheduler = OneCycleLR(optimizer=self.optimizer,
                                    pct_start=0.1,
                                    div_factor=1e3,
                                    max_lr=1e-2,
                                    epochs=config.n_epochs,
                                    steps_per_epoch=len(train_loader))
        # GradScaler only needed for fp16 (mixed-precision) training.
        self.scaler = GradScaler() if config.fp16 else None

        self.epoch = 0
        self.best_epoch = 0
        self.best_loss = np.inf

    def train_one_epoch(self):
        """Run one training epoch; return its average train loss."""
        self.model.train()
        self.summary_loss.reset()
        for step, (g_x, c_x, cate_x, labels,
                   non_labels) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            labels = labels.to(self.device)
            non_labels = non_labels.to(self.device)
            g_x = g_x.to(self.device)
            c_x = c_x.to(self.device)
            cate_x = cate_x.to(self.device)
            batch_size = labels.shape[0]

            with ExitStack() as stack:
                # Forward pass, under autocast when mixed precision is on.
                if self.config.fp16:
                    stack.enter_context(autocast())
                outputs = self.model(g_x, c_x, cate_x)
                loss = self.criterion(outputs, labels)

            if self.config.fp16:
                # Scaled backward/step avoids fp16 gradient underflow.
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()

            self.summary_loss.update(loss.item(), batch_size)
            # Per-batch schedulers (e.g. OneCycleLR) step here;
            # ReduceLROnPlateau would need the validation metric instead.
            if self.scheduler.__class__.__name__ != 'ReduceLROnPlateau':
                self.scheduler.step()

        self.history['train'].append(self.summary_loss.avg)
        return self.summary_loss.avg

    def validation(self):
        """Evaluate on the validation loader; return its average BCE loss."""
        self.model.eval()
        self.summary_loss.reset()
        for step, (g_x, c_x, cate_x, labels,
                   non_labels) in enumerate(self.valid_loader):
            with torch.no_grad():
                labels = labels.to(self.device)
                g_x = g_x.to(self.device)
                c_x = c_x.to(self.device)
                cate_x = cate_x.to(self.device)
                batch_size = labels.shape[0]
                outputs = self.model(g_x, c_x, cate_x)
                loss = self.evaluator(outputs, labels)
                self.summary_loss.update(loss.detach().item(), batch_size)
        self.history['valid'].append(self.summary_loss.avg)
        return self.summary_loss.avg

    def fit(self, epochs):
        """Train for ``epochs`` epochs; return the loss history dict."""
        self.log('Start training....')
        for e in range(epochs):
            train_loss = self.train_one_epoch()
            self.tb_logger.add_scalar('Train/Loss', train_loss, self.epoch)

            valid_loss = self.validation()
            self.tb_logger.add_scalar('Valid/Loss', valid_loss, self.epoch)

            self.post_processing(valid_loss)
            self.epoch += 1

        self.log(f'best epoch: {self.best_epoch}, best loss: {self.best_loss}')
        return self.history

    def post_processing(self, loss):
        """Checkpoint model/optimizer/scheduler when ``loss`` improves."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = self.epoch

            self.model.eval()
            ckpt_path = os.path.join(
                self.config.log_dir,
                f"{self.config.name}_seed{self.seed}_fold{self.fold}.pth")
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_loss': self.best_loss,
                    'epoch': self.epoch,
                }, ckpt_path)
            self.log(f'best model: {self.epoch} epoch - loss: {loss:.6f}')

    def load(self, path):
        """Restore training state from ``path`` (mapped to CPU)."""
        checkpoint = torch.load(path,
                                map_location=lambda storage, loc: storage)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_loss = checkpoint['best_loss']
        # Resume from the epoch after the saved one.
        self.epoch = checkpoint['epoch'] + 1

    def log(self, text):
        """Write ``text`` to the fold's file logger."""
        self.logger.info(text)
class Trainer():
    """Training harness for the seq2seq autocorrect model.

    Splits the given n-gram list into train/valid, builds synthetic-noise
    datasets, and trains a ``Seq2Seq`` model with label smoothing under a
    one-cycle LR schedule.  Hyperparameters come from module-level constants
    (DEVICE, NUM_ITERS, BATCH_SIZE, ...).
    """

    def __init__(self, alphabets_, list_ngram):
        """Build vocab, datasets, model and optimizer.

        :param alphabets_: alphabet string/list used to build the Vocab
        :param list_ngram: list of n-gram phrases to train/validate on
        """
        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))

        # Encoder input and decoder output both range over the vocabulary.
        INPUT_DIM = len(self.vocab)
        OUTPUT_DIM = len(self.vocab)

        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH
        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER
        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        self.metrics = MAX_SAMPLE_VALID

        # NOTE(review): self.logger is only bound when LOG is truthy, yet
        # train() uses it unconditionally — confirm LOG is always set.
        logger = LOG
        if logger:
            self.logger = Logger(logger)

        self.iter = 0
        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)
        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)
        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)
        self.train_losses = []

        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)

    def train_test_split(self, list_phrases, test_size=0.1):
        """Split ``list_phrases`` into a head (train) and tail (valid) part.

        :param list_phrases: sequence of phrases
        :param test_size: fraction reserved for validation (from the tail)
        :return: (train_list, valid_list)
        """
        train_idx = int(len(list_phrases) * (1 - test_size))
        list_phrases_train = list_phrases[:train_idx]
        list_phrases_valid = list_phrases[train_idx:]
        return list_phrases_train, list_phrases_valid

    def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True):
        """Build a DataLoader over a noise-augmented autocorrect dataset."""
        dataset = AutoCorrectDataset(list_ngrams_np,
                                     transform_noise=synthesizer,
                                     vocab=vocab,
                                     maxlen=MAXLEN)
        gen = DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=bool(is_train),
                         drop_last=False)
        return gen

    def step(self, batch):
        """One optimization step; returns the scalar loss value."""
        self.model.train()
        batch = self.batch_to_device(batch)
        src, tgt = batch['src'], batch['tgt']
        # batch x src_len -> src_len x batch (model expects time-major).
        src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)

        outputs = self.model(src, tgt)  # B x tgt_len x vocab

        outputs = outputs.view(-1, outputs.size(2))
        # tgt is tgt_len x B; convert back to B x tgt_len before flattening
        # so rows line up with the batch-major outputs.
        tgt_output = tgt.transpose(0, 1).reshape(-1)
        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()
        return loss_item

    def train(self):
        """Main optimization loop with periodic logging and validation."""
        print("Begin training from iter: ", self.iter)
        total_loss = 0
        total_loader_time = 0
        total_gpu_time = 0
        best_acc = -1

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: restart for the next epoch.
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.iter % self.valid_every == 0:
                val_loss, preds, actuals, inp_sents = self.validate()
                acc_full_seq, acc_per_char, cer = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, cer)
                print(info)
                print("--- Sentence predict ---")
                for pred, inp, label in zip(preds, inp_sents, actuals):
                    infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format(
                        pred, inp, label)
                    print(infor_predict)
                    self.logger.log(infor_predict)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq
                self.save_checkpoint(self.checkpoint)

    def validate(self):
        """Compute valid loss on up to ``metrics`` samples.

        :return: (mean_loss, first-3 preds, first-3 labels, first-3 inputs)
        """
        self.model.eval()
        total_loss = []
        max_step = self.metrics / self.batch_size
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                src, tgt = batch['src'], batch['tgt']
                src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)

                outputs = self.model(src, tgt, 0)  # turn off teaching force

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt.flatten()
                loss = self.criterion(outputs, tgt_output)
                total_loss.append(loss.item())

                preds, actuals, inp_sents, probs = self.predict(5)
                del outputs
                del loss
                if step > max_step:
                    break

        total_loss = np.mean(total_loss)
        self.model.train()
        return total_loss, preds[:3], actuals[:3], inp_sents[:3]

    def predict(self, sample=None):
        """Decode the validation loader (greedy or beam search).

        :param sample: stop once more than this many predictions collected
        :return: (pred_sents, actual_sents, inp_sents, prob of last batch)
        """
        # NOTE(review): `prob` is unbound if the valid loader is empty, and
        # only the LAST batch's prob is returned — confirm intended.
        pred_sents = []
        actual_sents = []
        inp_sents = []
        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)
            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['src'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['src'],
                                                      self.model)
            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt'].tolist())
            inp_sent = self.vocab.batch_decode(batch['src'].tolist())

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            inp_sents.extend(inp_sent)
            if sample is not None and len(pred_sents) > sample:
                break
        return pred_sents, actual_sents, inp_sents, prob

    def precision(self, sample=None):
        """Return (full-sequence acc, per-char acc, CER) on the valid set."""
        pred_sents, actual_sents, _, _ = self.predict(sample=sample)
        acc_full_seq = compute_accuracy(actual_sents, pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents, pred_sents,
                                        mode='per_char')
        cer = compute_accuracy(actual_sents, pred_sents, mode='CER')
        return acc_full_seq, acc_per_char, cer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        """Collect (optionally error-only) predictions for inspection.

        NOTE(review): plotting code appears to have been stripped; this
        currently only filters/truncates the prediction lists.
        """
        pred_sents, actual_sents, img_files, probs = self.predict(sample)
        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)
            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
        img_files = img_files[:sample]
        fontdict = {'family': fontname, 'size': fontsize}

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Walk ``sample`` training items (plotting code stripped upstream)."""
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())
                n += 1
                if n >= sample:
                    return

    def load_checkpoint(self, filename):
        """Restore full training state from a checkpoint file."""
        checkpoint = torch.load(filename)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Save full training state (model, optimizer, scheduler, counters)."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': self.scheduler.state_dict()
        }
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(state, filename)

    def load_weights(self, filename):
        """Load a raw state_dict, dropping shape-mismatched entries."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape, required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                del state_dict[name]
        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state_dict to ``filename``."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        """Move src/tgt tensors to the training device (non-blocking)."""
        src = batch['src'].to(self.device, non_blocking=True)
        tgt = batch['tgt'].to(self.device, non_blocking=True)
        batch = {'src': src, 'tgt': tgt}
        return batch
def main(cfg):
    """Train an IQA model on the AVA dataset per the ``cfg`` namespace.

    Prepares the workspace and memmapped datasets, then alternates between
    ``eval_interval`` training steps and a full validation pass, saving a
    checkpoint every round and tracking the best one by validation loss.
    """
    # Workspace, logging and config snapshot.
    workdir = Path(cfg.workdir)
    workdir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    set_logger(workdir / 'log.txt')
    cfg.dump_to_file(workdir / 'config.yml')
    saver = Saver(workdir, keep_num=10)
    logging.info(f'config: \n{cfg}')
    logging.info(f'use device: {device}')

    # Model (wrapped in DataParallel when several GPUs are visible).
    model = iqa.__dict__[cfg.model.name](**cfg.model.kwargs)
    model = model.to(device)
    model_dp = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model

    # Transforms: flips only for training.
    train_transform = Transform(
        transforms.Compose([
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ]))
    val_transform = Transform(
        transforms.Compose([transforms.RandomCrop(224),
                            transforms.ToTensor()]))

    # Build memmap caches on first run.
    # NOTE(review): the val cache is built from cfg.ava.train_labels — looks
    # like a copy-paste of the train branch; confirm whether a val_labels
    # field should be used here.
    if not Path(cfg.ava.train_cache).exists():
        create_memmap(cfg.ava.train_labels, cfg.ava.images,
                      cfg.ava.train_cache, cfg.num_workers)
    if not Path(cfg.ava.val_cache).exists():
        create_memmap(cfg.ava.train_labels, cfg.ava.images, cfg.ava.val_cache,
                      cfg.num_workers)

    trainset = MemMap(cfg.ava.train_cache, train_transform)
    valset = MemMap(cfg.ava.val_cache, val_transform)

    # One "round" of training = one pass worth of steps between evals.
    eval_interval = len(trainset) // cfg.batch_size
    total_steps = eval_interval * cfg.num_epochs
    logging.info(f'total steps: {total_steps}, eval interval: {eval_interval}')

    model_dp.train()
    parameters = group_parameters(model)
    optimizer = SGD(parameters,
                    cfg.lr,
                    cfg.momentum,
                    weight_decay=cfg.weight_decay)
    lr_scheduler = OneCycleLR(optimizer,
                              max_lr=cfg.lr,
                              div_factor=cfg.lr / cfg.warmup_lr,
                              total_steps=total_steps,
                              pct_start=0.01,
                              final_div_factor=cfg.warmup_lr / cfg.final_lr)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=cfg.batch_size,
                                               shuffle=True,
                                               num_workers=cfg.num_workers,
                                               drop_last=True,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers,
                                             pin_memory=True)

    # Save an initial (step-0) checkpoint before any training.
    curr_loss = 1e9
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'step': 0,
        'cfg': cfg,
        'loss': curr_loss
    }
    saver.save(0, state)

    trainloader = repeat_loader(train_loader)
    batch_processor = BatchProcessor(device)
    start = time.time()
    for step in range(0, total_steps, eval_interval):
        # Last round may be shorter than eval_interval.
        num_steps = min(step + eval_interval, total_steps) - step
        step += num_steps  # report the step count reached after this round

        trainmeter = train_steps(model_dp, trainloader, optimizer,
                                 lr_scheduler, emd_loss, batch_processor,
                                 num_steps)
        valmeter = evaluate(model_dp, val_loader, emd_loss, batch_processor)

        finish = time.time()
        img_s = cfg.batch_size * eval_interval / (finish - start)
        loss = valmeter.meters['loss'].global_avg

        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'step': step,
            'cfg': cfg,
            'loss': loss
        }
        saver.save(step, state)
        if loss < curr_loss:
            curr_loss = loss
            saver.save_best(state)

        logging.info(
            f'step: [{step}/{total_steps}] img_s: {img_s:.2f} train: [{trainmeter}] eval:[{valmeter}]'
        )
        start = time.time()
def main(args):
    """Train a WiderFace detector with periodic evaluation and checkpointing.

    Runs a flat step-based training loop; every `eval_interval` steps it logs
    throughput, evaluates, checkpoints, and tracks the best validation loss.

    :param args: configuration object (workdir, dataset paths, model name,
        optimizer/scheduler hyper-parameters, total_steps, eval_interval, ...)
    """
    # prepare workspace
    workdir = Path(args.workdir)
    workdir.mkdir(parents=True, exist_ok=True)
    logger = get_logger(workdir / 'log.txt')
    logger.info(f'config: \n{args}')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.device:
        # Explicit user override takes precedence over auto-detection.
        logger.info(f'user specify device: {args.device}')
        device = torch.device(args.device)
    logger.info(f'use device: {device}')
    # dump all configures to later use, such as for testing
    with open(workdir / 'config.yml', 'wt') as f:
        args.dump(stream=f)
    saver = Saver(workdir, keep_num=10)

    # prepare dataset
    valtransform = ValTransform(dsize=args.dsize)
    traintransform = TrainTransform(dsize=args.dsize, **args.augments)
    trainset = WiderFace(args.train_label, args.train_image, min_face=1,
                         with_shapes=True, transform=traintransform)
    valset = WiderFace(args.val_label, args.val_image,
                       transform=valtransform, min_face=1)
    trainloader = DataLoader(trainset, batch_size=args.batch_size,
                             shuffle=True, num_workers=args.num_workers,
                             pin_memory=True, collate_fn=wider_collate,
                             drop_last=True)
    valloader = DataLoader(valset, batch_size=args.batch_size,
                           shuffle=False, num_workers=args.num_workers,
                           pin_memory=True, collate_fn=wider_collate)

    # model: backbone wrapped in a Detector with the bbox/shape prior
    model = models.__dict__[args.model.name](phase='train').to(device)
    prior = BBoxShapePrior(args.num_classes, 5, args.anchors,
                           args.iou_threshold, args.encode_mean,
                           args.encode_std)
    model = Detector(prior, model)

    # optimizer and lr scheduler (biases excluded from weight decay)
    parameters = group_parameters(model, bias_decay=0)
    optimizer = SGD(parameters, lr=args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay)
    lr_scheduler = OneCycleLR(optimizer, max_lr=args.lr, div_factor=20,
                              total_steps=args.total_steps, pct_start=0.1,
                              final_div_factor=100)

    trainloader = repeat_loader(trainloader)
    model.to(device)
    model.train()

    best_loss = 1e9
    state = {
        'model': model.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': 0,
        'loss': best_loss
    }
    saver.save(0, state)

    def reset_meter():
        # Fresh meter per logging window; lr shows last value, not an average.
        meter = MetricLogger()
        meter.add_meter('lr', SmoothedValue(1, fmt='{value:.5f}'))
        return meter

    train_meter = reset_meter()
    start = time.time()
    for step in range(args.start_step, args.total_steps):
        batch = next(trainloader)
        batch = batch_to(batch, device)
        image = batch['image']
        box = batch['bbox']
        point = batch['shape']
        mask = batch['mask']
        label = batch['label']
        score_loss, box_loss, point_loss = model(
            image, targets=(label, box, point, mask))
        # Box regression is weighted 2x relative to score/shape losses.
        loss = score_loss + 2.0 * box_loss + point_loss
        train_meter.meters['score'].update(score_loss.item())
        train_meter.meters['box'].update(box_loss.item())
        train_meter.meters['shape'].update(point_loss.item())
        train_meter.meters['total'].update(loss.item())
        train_meter.meters['lr'].update(optimizer.param_groups[0]['lr'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        if (step + 1) % args.eval_interval == 0:
            duration = time.time() - start
            img_s = args.eval_interval * args.batch_size / duration
            eval_meter = evaluate(model, valloader, prior, device)
            logger.info(
                f'Step [{step + 1}/{args.total_steps}] img/s: {img_s:.2f} train: [{train_meter}] eval: [{eval_meter}]'
            )
            train_meter = reset_meter()
            start = time.time()
            curr_loss = eval_meter.meters['total'].global_avg
            state = {
                'model': model.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
                # BUG FIX: the step slot previously stored curr_loss
                # ('step': curr_loss) and the 'loss' key — present in the
                # initial checkpoint above — was missing entirely.
                'step': step + 1,
                'loss': curr_loss,
            }
            saver.save(step + 1, state)
            if (curr_loss < best_loss):
                best_loss = curr_loss
                saver.save_best(state)
def main():
    """Fine-tune an MLP head on top of a frozen pretrained timm backbone.

    Pipeline: parse args, build the (headless) pretrained feature extractor
    and the MLP classifier, optionally split caltech101 into train/val/test
    folders, build augmented loaders, then train with optional AMP (native
    or apex), logging per-epoch metrics to a CSV and checkpointing every
    epoch so interrupted runs can resume.
    """
    setup_default_logging()
    args, args_text = _parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    _logger.info('====================\n\n'
                 'Actfun: {}\n'
                 'LR: {}\n'
                 'Epochs: {}\n'
                 'p: {}\n'
                 'k: {}\n'
                 'g: {}\n'
                 'Extra channel multiplier: {}\n'
                 'Weight Init: {}\n'
                 '\n===================='.format(
                     args.actfun, args.lr, args.epochs, args.p, args.k,
                     args.g, args.extra_channel_mult, args.weight_init))

    # ---------------------------------------------------------- Loading models
    pre_model = create_model(
        args.model,
        pretrained=True,
        actfun='swish',
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        p=args.p,
        k=args.k,
        g=args.g,
        extra_channel_mult=args.extra_channel_mult,
        weight_init_name=args.weight_init,
        partial_ho_actfun=args.partial_ho_actfun)
    # Drop the classifier head: keep only the feature-extractor layers.
    pre_model_layers = list(pre_model.children())
    pre_model = torch.nn.Sequential(*pre_model_layers[:-1])
    pre_model.to(device)
    # Trainable MLP head on the 1280-d backbone features.
    model = MLP.MLP(actfun=args.actfun,
                    input_dim=1280,
                    output_dim=args.num_classes,
                    k=args.k,
                    p=args.p,
                    g=args.g,
                    num_params=1_000_000,
                    permute_type='shuffle')
    model.to(device)

    # --------------------------------------------------------- Loading dataset
    util.seed_all(args.seed)
    if args.data == 'caltech101' and not os.path.exists('caltech101'):
        # One-time 80/10/10 split of the raw caltech101 download into
        # train/val/test directory trees (BACKGROUND_Google class skipped).
        dir_root = r'101_ObjectCategories'
        dir_new = r'caltech101'
        dir_new_train = os.path.join(dir_new, 'train')
        dir_new_val = os.path.join(dir_new, 'val')
        dir_new_test = os.path.join(dir_new, 'test')
        if not os.path.exists(dir_new):
            os.mkdir(dir_new)
            os.mkdir(dir_new_train)
            os.mkdir(dir_new_val)
            os.mkdir(dir_new_test)
        for dir2 in os.listdir(dir_root):
            if dir2 != 'BACKGROUND_Google':
                curr_path = os.path.join(dir_root, dir2)
                new_path_train = os.path.join(dir_new_train, dir2)
                new_path_val = os.path.join(dir_new_val, dir2)
                new_path_test = os.path.join(dir_new_test, dir2)
                if not os.path.exists(new_path_train):
                    os.mkdir(new_path_train)
                if not os.path.exists(new_path_val):
                    os.mkdir(new_path_val)
                if not os.path.exists(new_path_test):
                    os.mkdir(new_path_test)
                train_upper = int(0.8 * len(os.listdir(curr_path)))
                val_upper = int(0.9 * len(os.listdir(curr_path)))
                curr_files_all = os.listdir(curr_path)
                curr_files_train = curr_files_all[:train_upper]
                curr_files_val = curr_files_all[train_upper:val_upper]
                curr_files_test = curr_files_all[val_upper:]
                for file in curr_files_train:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_train, file))
                for file in curr_files_val:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_val, file))
                for file in curr_files_test:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_test, file))
        time.sleep(5)

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)
    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        # Fall back to 'validation' naming convention.
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            # collate conflict (need to support deinterleaving in collate mixup)
            assert not num_aug_splits
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ----------------------------------------------- Optimizer / scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=args.lr,
        epochs=args.epochs,
        steps_per_epoch=int(math.floor(len(dataset_train) / args.batch_size)),
        cycle_momentum=False)

    # ----------------------------------------------- Save file / checkpoints
    fieldnames = [
        'dataset', 'seed', 'epoch', 'time', 'actfun', 'model', 'batch_size',
        'alpha_primes', 'alphas', 'num_params', 'k', 'p', 'g', 'perm_method',
        'gen_gap', 'epoch_train_loss', 'epoch_train_acc',
        'epoch_aug_train_loss', 'epoch_aug_train_acc', 'epoch_val_loss',
        'epoch_val_acc', 'curr_lr', 'found_lr', 'epochs'
    ]
    filename = 'out_{}_{}_{}_{}'.format(datetime.date.today(), args.actfun,
                                        args.data, args.seed)
    outfile_path = os.path.join(args.output, filename) + '.csv'
    checkpoint_path = os.path.join(args.check_path, filename) + '.pth'
    if not os.path.exists(outfile_path):
        with open(outfile_path, mode='w') as out_file:
            writer = csv.DictWriter(out_file,
                                    fieldnames=fieldnames,
                                    lineterminator='\n')
            writer.writeheader()

    epoch = 1
    checkpoint = torch.load(checkpoint_path) if os.path.exists(
        checkpoint_path) else None
    if checkpoint is not None:
        # Resume: restore both models, optimizer, scheduler and epoch counter.
        pre_model.load_state_dict(checkpoint['pre_model_state_dict'])
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        epoch = checkpoint['epoch']
        pre_model.to(device)
        model.to(device)
        print("*** LOADED CHECKPOINT ***"
              "\n{}"
              "\nSeed: {}"
              "\nEpoch: {}"
              "\nActfun: {}"
              "\np: {}"
              "\nk: {}"
              "\ng: {}"
              "\nperm_method: {}".format(checkpoint_path,
                                         checkpoint['curr_seed'],
                                         checkpoint['epoch'],
                                         checkpoint['actfun'],
                                         checkpoint['p'], checkpoint['k'],
                                         checkpoint['g'],
                                         checkpoint['perm_method']))

    args.mix_pre_apex = False
    if args.control_amp == 'apex':
        args.mix_pre_apex = True
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # --------------------------------------------------------------- Training
    while epoch <= args.epochs:
        if args.check_path != '':
            # Checkpoint at the top of each epoch so a crash loses <=1 epoch.
            torch.save(
                {
                    'pre_model_state_dict': pre_model.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'curr_seed': args.seed,
                    'epoch': epoch,
                    'actfun': args.actfun,
                    'p': args.p,
                    'k': args.k,
                    'g': args.g,
                    'perm_method': 'shuffle'
                }, checkpoint_path)
        util.seed_all((args.seed * args.epochs) + epoch)
        start_time = time.time()
        args.mix_pre = False
        if args.control_amp == 'native':
            args.mix_pre = True
            scaler = torch.cuda.amp.GradScaler()

        # ---- Training
        model.train()
        total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
        for batch_idx, (x, targetx) in enumerate(loader_train):
            x, targetx = x.to(device), targetx.to(device)
            optimizer.zero_grad()
            if args.mix_pre:
                # Native AMP path.
                with torch.cuda.amp.autocast():
                    with torch.no_grad():
                        x = pre_model(x)  # frozen backbone: no grads needed
                    output = model(x)
                    train_loss = criterion(output, targetx)
                # BUG FIX (all branches): accumulate .item() instead of the
                # loss tensor — accumulating the tensor retained the autograd
                # graph of every batch for the whole epoch.
                total_train_loss += train_loss.item()
                n += 1
                scaler.scale(train_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            elif args.mix_pre_apex:
                # Apex AMP path.
                with torch.no_grad():
                    x = pre_model(x)
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss.item()
                n += 1
                with amp.scale_loss(train_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                # Full-precision path.
                with torch.no_grad():
                    x = pre_model(x)
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss.item()
                n += 1
                train_loss.backward()
                optimizer.step()
            # OneCycleLR steps per batch (steps_per_epoch given above).
            scheduler.step()
            _, prediction = torch.max(output.data, 1)
            num_correct += torch.sum(prediction == targetx.data)
            num_total += len(prediction)
        epoch_aug_train_loss = total_train_loss / n
        epoch_aug_train_acc = num_correct * 1.0 / num_total

        # Record combinact mixing weights for logging, when applicable.
        alpha_primes = []
        alphas = []
        if model.actfun == 'combinact':
            for i, layer_alpha_primes in enumerate(model.all_alpha_primes):
                curr_alpha_primes = torch.mean(layer_alpha_primes, dim=0)
                curr_alphas = F.softmax(curr_alpha_primes, dim=0).data.tolist()
                curr_alpha_primes = curr_alpha_primes.tolist()
                alpha_primes.append(curr_alpha_primes)
                alphas.append(curr_alphas)

        # ---- Validation
        model.eval()
        with torch.no_grad():
            total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (y, targety) in enumerate(loader_eval):
                y, targety = y.to(device), targety.to(device)
                with torch.no_grad():
                    y = pre_model(y)
                output = model(y)
                val_loss = criterion(output, targety)
                total_val_loss += val_loss.item()
                n += 1
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targety.data)
                num_total += len(prediction)
            epoch_val_loss = total_val_loss / n
            epoch_val_acc = num_correct * 1.0 / num_total

        lr_curr = 0
        for param_group in optimizer.param_groups:
            lr_curr = param_group['lr']
        print(
            " Epoch {}: LR {:1.5f} ||| aug_train_acc {:1.4f} | val_acc {:1.4f} ||| "
            "aug_train_loss {:1.4f} | val_loss {:1.4f} ||| time = {:1.4f}".
            format(epoch, lr_curr, epoch_aug_train_acc, epoch_val_acc,
                   epoch_aug_train_loss, epoch_val_loss,
                   (time.time() - start_time)),
            flush=True)

        # On the final epoch, re-measure losses in eval mode for the CSV.
        epoch_train_loss = 0
        epoch_train_acc = 0
        if epoch == args.epochs:
            with torch.no_grad():
                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loader_train):
                    x, targetx = x.to(device), targetx.to(device)
                    with torch.no_grad():
                        x = pre_model(x)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss.item()
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                epoch_aug_train_loss = total_train_loss / n
                epoch_aug_train_acc = num_correct * 1.0 / num_total

                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                # NOTE(review): this loop iterates loader_eval even though
                # the result is logged as 'epoch_train_loss' — confirm whether
                # loader_train (without augmentation) was intended here.
                for batch_idx, (x, targetx) in enumerate(loader_eval):
                    x, targetx = x.to(device), targetx.to(device)
                    with torch.no_grad():
                        x = pre_model(x)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss.item()
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                # BUG FIX: this divided total_val_loss (left over from the
                # validation pass) instead of the accumulator actually
                # updated by this loop, total_train_loss.
                epoch_train_loss = total_train_loss / n
                epoch_train_acc = num_correct * 1.0 / num_total

        # Outputting data to CSV at end of epoch
        with open(outfile_path, mode='a') as out_file:
            writer = csv.DictWriter(out_file,
                                    fieldnames=fieldnames,
                                    lineterminator='\n')
            writer.writerow({
                'dataset': args.data,
                'seed': args.seed,
                'epoch': epoch,
                'time': (time.time() - start_time),
                'actfun': model.actfun,
                'model': args.model,
                'batch_size': args.batch_size,
                'alpha_primes': alpha_primes,
                'alphas': alphas,
                'num_params': util.get_model_params(model),
                'k': args.k,
                'p': args.p,
                'g': args.g,
                'perm_method': 'shuffle',
                'gen_gap': float(epoch_val_loss - epoch_train_loss),
                'epoch_train_loss': float(epoch_train_loss),
                'epoch_train_acc': float(epoch_train_acc),
                'epoch_aug_train_loss': float(epoch_aug_train_loss),
                'epoch_aug_train_acc': float(epoch_aug_train_acc),
                'epoch_val_loss': float(epoch_val_loss),
                'epoch_val_acc': float(epoch_val_acc),
                'curr_lr': lr_curr,
                'found_lr': args.lr,
                'epochs': args.epochs
            })

        epoch += 1
def main():
    """Train SpenceNet: a multi-head letter classifier + keypoint regressor.

    Parses CLI hyper-parameters, builds per-letter datasets with a custom
    batch sampler (so every batch holds one letter class), trains with AdamW
    + OneCycleLR and a wing-loss-based multi-task criterion, logs to WandB,
    and checkpoints each epoch (tracking the best test loss).
    """
    # Training settings and hyperparameters
    parser = argparse.ArgumentParser(description='SpenceNet Pytorch Training')
    parser.add_argument('--encoder', default='XResNet34', type=str,
                        choices=['XResNet18', 'XResNet34', 'XResNet50'],
                        help='encoder architecture (default: XResNet34)')
    parser.add_argument('--num_workers', default=2, type=int,
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=30, type=int,
                        help='number of total training epochs')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--use_grayscale', default=True,
                        help='turn input images to grayscale (default: True)')
    parser.add_argument('--img_size', type=int, default=300,
                        help='target image size for training (default: 300)')
    parser.add_argument('--max_lr', type=float, default=0.001,
                        help='maximum learning rate (default: 0.001)')
    parser.add_argument('--encoder_lr_mult', type=float, default=0.25,
                        help='encoder_lr = max_lr * this value (0.25 default)')
    parser.add_argument('--weight_decay', type=float, default=0.001,
                        help='weight decay (default: 0.001)')
    parser.add_argument('--sched_pct_start', type=float, default=0.3,
                        help='OneCycleLR pct_start parameter (default: 0.3)')
    parser.add_argument('--sched_div_factor', type=float, default=10.0,
                        help='OneCycleLR div factor (default: 10.0)')
    parser.add_argument('--wing_loss_e', type=float, default=2.0,
                        help='Wing Loss e parameter (default: 2.0)')
    parser.add_argument('--wing_loss_w', type=float, default=10.0,
                        help='Wing Loss w parameter (default: 10.0)')
    parser.add_argument('--use_cuda', default=True,
                        help='Enables CUDA training (default: True)')
    parser.add_argument('--seed', type=int, default=None,
                        help='fix random seed for training (default: None)')
    parser.add_argument('--wandb_project', default='multi-head-spencenet',
                        type=str, help='WandB project name')
    parser.add_argument('--save_dir', default='saved/', type=str,
                        help='directory to save outputs in (default: saved/)')
    parser.add_argument('--resume', default='', type=str,
                        help='path to checkpoint to optionally resume from')
    config = parser.parse_args()
    wandb_config = vars(config)  # WandB expects dictionary

    # Get timestamp (US/Pacific) — used both as run name and save dir.
    today = datetime.now(tz=utc)
    today = today.astimezone(timezone('US/Pacific'))
    timestamp = today.strftime("%b_%d_%Y_%H_%M")
    wandb.init(config=wandb_config,
               project=config.wandb_project,
               dir=config.save_dir,
               name=timestamp,
               id=timestamp)

    use_cuda = config.use_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': config.num_workers,
              'pin_memory': True} if use_cuda else {}

    # Fix random seeds and deterministic pytorch for reproducibility
    if config.seed:
        # BUG FIX: config is an argparse.Namespace, which is not
        # subscriptable — config['seed'] raised TypeError. Use attribute
        # access.
        torch.manual_seed(config.seed)  # pytorch random seed
        np.random.seed(config.seed)  # numpy random seed
        torch.backends.cudnn.deterministic = True

    # DATASET LOADING
    # Letters Dictionary >>> Class ID: [Letter Name, # of Coordinate Values]
    letter_dict = {0: ['alpha', 20], 1: ['beta', 28], 2: ['gamma', 16]}
    letter_ordered_dict = OrderedDict(sorted(letter_dict.items()))

    # Define the tranformations
    train_transforms = transforms.Compose([
        lt.RandomCrop(10),
        lt.RandomRotate(10),
        lt.RandomLightJitter(0.2),
        lt.RandomPerspective(0.5),
        lt.Resize(config.img_size),
        lt.ToNormalizedTensor()
    ])
    test_transforms = transforms.Compose([
        lt.Resize(config.img_size),
        lt.ToNormalizedTensor()
    ])
    # Add grayscale transform
    if config.use_grayscale:
        train_transforms.transforms.insert(0, lt.ToGrayscale())
        test_transforms.transforms.insert(0, lt.ToGrayscale())

    # Define separate datasets for each annotated class
    letters = [key for key, val in letter_dict.items() if val[1] != 0]
    train_ds_list = []
    test_ds_list = []
    for letter in letters:
        train_ds_list.append(
            LetterDataset(f'./data/{letter_dict[letter][0]}_small_data.csv',
                          num_coordinates=letter_dict[letter][1],
                          transform=train_transforms))
        test_ds_list.append(
            LetterDataset(f'./data/{letter_dict[letter][0]}_small_data.csv',
                          is_validation=True,
                          num_coordinates=letter_dict[letter][1],
                          transform=test_transforms))

    # Concatenated Datasets
    train_datasets = ConcatDataset(train_ds_list)
    test_datasets = ConcatDataset(test_ds_list)

    # Define Dataloaders with custom LetterBatchSampler
    train_loader = DataLoader(dataset=train_datasets,
                              sampler=LetterBatchSampler(
                                  dataset=train_datasets,
                                  batch_size=config.batch_size,
                                  drop_last=True),
                              batch_size=config.batch_size,
                              **kwargs)
    test_loader = DataLoader(dataset=test_datasets,
                             sampler=LetterBatchSampler(
                                 dataset=test_datasets,
                                 batch_size=config.batch_size,
                                 drop_last=True),
                             batch_size=config.batch_size,
                             **kwargs)

    # INITIALIZE MODEL
    model = SpenceNet(letter_ordered_dict,
                      backbone=config.encoder,
                      c_in=1 if config.use_grayscale else 3,
                      img_size=config.img_size).to(device)

    # Encoder gets a reduced LR (fine-tuning); heads train at max_lr.
    optimizer = optim.AdamW([
        {'params': model.encoder.parameters(),
         'lr': config.max_lr * config.encoder_lr_mult},
        {'params': model.classification_head.parameters()},
        {'params': model.keypoint_heads.parameters()}
    ], lr=config.max_lr, betas=(0.9, 0.99),
        weight_decay=config.weight_decay)

    # Initialize Loss Function
    criterion = MultiLoss(e=config.wing_loss_e, w=config.wing_loss_w)

    # LR Scheduler
    scheduler = OneCycleLR(optimizer,
                           max_lr=config.max_lr,
                           pct_start=config.sched_pct_start,
                           div_factor=config.sched_div_factor,
                           steps_per_epoch=len(train_loader),
                           epochs=config.epochs)

    # Optionally resume from saved checkpoint
    if config.resume:
        model, optimizer, scheduler, curr_epoch, ckp_loss = load_checkpoint(
            config.resume, model, optimizer, scheduler)
        start_epoch = curr_epoch
        best_loss = ckp_loss
        print(f'Resuming from checkpoint... Epoch: {start_epoch} Loss: {best_loss:.4f}')
    else:
        start_epoch = 0
        best_loss = math.inf

    # Track all gradients/parameters with WandB
    wandb.watch(model, log='all')

    # Training start time
    training_start = time.time()
    for epoch in range(start_epoch, config.epochs):
        train_metrics = train(config, model, device, train_loader, optimizer,
                              scheduler, criterion)
        test_metrics = test(config, model, device, test_loader, criterion,
                            len(letters))
        # Log training data and metrics
        # TODO: in test, randomly return 4 img per class based on len(letters)
        log_metrics(timestamp, training_start, epoch, config.epochs,
                    train_metrics, test_metrics)
        # Checkpoint saving
        is_best = test_metrics['test_multi_loss'] < best_loss
        best_loss = min(test_metrics['test_multi_loss'], best_loss)
        save_checkpoint({'epoch': epoch,
                         'loss': test_metrics['test_multi_loss'],
                         'model_state': model.state_dict(),
                         'opt_state': optimizer.state_dict(),
                         'sched_state': scheduler.state_dict()},
                        is_best,
                        checkpoint_dir=f'saved/{timestamp}/')
def main(args):
    """Train a facial-landmark model with step-based chunked training.

    Builds model/optimizer/scheduler and the train/val datasets, then
    alternates `eval_interval`-sized training chunks with evaluation,
    checkpointing after every chunk and tracking the best validation loss.

    :param args: configuration object (workdir, model/data specs, lr,
        momentum, weight_decay, total_steps, eval_interval, ...)
    """
    workdir = Path(args.workdir)
    workdir.mkdir(parents=True, exist_ok=True)
    logger = get_logger(workdir / 'log.txt')
    logger.info(f'config:\n{args}')
    saver = Saver(workdir, keep_num=10)
    # dump the full configuration for later reuse (e.g. testing)
    with open(workdir / 'config.yml', 'wt') as f:
        args.dump(stream=f)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f'use device: {device}')

    # Model output size is derived from the landmark symmetry spec.
    num_points = len(args.data.symmetry)
    model = models.__dict__[args.model.name](num_points)
    model.to(device)
    # Biases excluded from weight decay.
    parameters = group_parameters(model, bias_decay=0)
    optimizer = SGD(parameters, args.lr, args.momentum,
                    weight_decay=args.weight_decay)
    lr_scheduler = OneCycleLR(optimizer, max_lr=args.lr, div_factor=20,
                              total_steps=args.total_steps, pct_start=0.1,
                              final_div_factor=100)

    # datasets — train transform additionally gets symmetry + augments
    valtransform = Transform(args.dsize, args.padding, args.data.meanshape,
                             args.data.meanbbox)
    traintransform = Transform(args.dsize, args.padding, args.data.meanshape,
                               args.data.meanbbox, args.data.symmetry,
                               args.augments)
    traindata = datasets.__dict__[args.data.name](**args.data.train)
    valdata = datasets.__dict__[args.data.name](**args.data.val)
    traindata.transform = traintransform
    valdata.transform = valtransform
    trainloader = DataLoader(traindata, args.batch_size, shuffle=True,
                             drop_last=True, num_workers=args.num_workers,
                             pin_memory=True)
    valloader = DataLoader(valdata, args.batch_size, False,
                           num_workers=args.num_workers, pin_memory=False)

    def repeat(loader):
        # Endless iterator over the training loader (re-shuffles each pass).
        while True:
            for batch in loader:
                yield batch

    best_loss = 1e9
    # Initial checkpoint at step 0, before any training.
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'step': 0,
        'loss': best_loss,
        'cfg': args
    }
    score_fn = IbugScore(args.left_eye, args.right_eye)
    saver.save(0, state)
    repeatloader = repeat(trainloader)
    start = time.time()
    for step in range(0, args.total_steps, args.eval_interval):
        # The final chunk may be shorter than eval_interval.
        num_steps = min(args.eval_interval, args.total_steps - step)
        step += num_steps  # report the step count *after* this chunk
        trainmeter = train_steps(model, repeatloader, optimizer,
                                 lr_scheduler, score_fn, device, num_steps)
        evalmeter = evaluate(model, valloader, score_fn, device)
        curr_loss = evalmeter.meters['loss'].global_avg
        finish = time.time()
        # Throughput over the whole chunk, including the evaluation pass.
        img_s = num_steps * args.batch_size / (finish - start)
        logger.info(
            f'step: [{step}/{args.total_steps}] img/s: {img_s:.2f} train: [{trainmeter}] eval: [{evalmeter}]'
        )
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'step': step,
            'loss': curr_loss,
            'cfg': args
        }
        saver.save(step, state)
        if curr_loss < best_loss:
            saver.save_best(state)
            best_loss = curr_loss
        start = time.time()
loss_history = [] for loss in train_epoch(model, train_loader, optimizer, scheduler, config.lambda_sparse, device): loss_history.append(loss) logger.info(f"Epoch {ep:03d}/{num_epochs:03d}, train loss: {np.mean(loss_history):.6f}") ### Predict on validation ### logit_log_loss, auc = validation(model, val_loader, device) if scheduler.__class__.__name__ == 'ReduceLROnPlateau': scheduler.step(logit_log_loss) if logit_log_loss < best_logit_log_loss: best_logit_log_loss = logit_log_loss # best_auc = auc write_this = { 'model': model.state_dict(), 'optim': optimizer.state_dict(), 'sched': scheduler.state_dict(), 'epoch': ep, } torch.save(write_this, filepath) logger.info(f" ** Updated the best weight, logit log loss: {logit_log_loss:.6f}, auc: {auc:.6f} **") else: logger.info(f"Passed to save the weight, best: {best_logit_log_loss:.6f} / logit log loss: {logit_log_loss:.6f}, auc: {auc:.6f}") ### Save OOF for CV ### best_state_dict = torch.load(filepath) model.load_state_dict(best_state_dict['model']) val_loader = DataLoader( DatasetWithoutLabel(X_val), batch_size=config.batch_size, collate_fn=collate_fn_test, shuffle=False,