print('| End of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
      'valid perplexity {:8.2f}'.format(epoch, epoch_duration, val_loss,
                                        math.exp(val_loss)))
print('-' * 89)

if val_loss < best_val_loss:
    best_val_loss = val_loss
    # delete the previous checkpoint
    if model_save_path is not None:
        os.remove(model_save_path)
    # save the current state
    model_save_path = os.path.join(exp_dir, f'model_{epoch}.pt')
    torch.save({
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'epoch': epoch,
        'step': steps
    }, model_save_path)
    epochs_wo_improvement = 0
else:
    epochs_wo_improvement += 1

stats = {
    'train_losses': train_losses,
    'valid_losses': val_losses,
    'mean_epoch_duration': np.mean(durations)
}
with open(os.path.join(exp_dir, 'stats.pkl'), 'wb') as f:
    pickle.dump(stats, f)
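# The `epochs_wo_improvement` counter above is only incremented; in a full
# loop it would typically gate early stopping. A minimal sketch of that
# pattern, assuming a `patience` hyperparameter and an `evaluate` helper that
# are not part of the original snippet:
patience = 5                      # assumed hyperparameter
best_val_loss = float('inf')
epochs_wo_improvement = 0

for epoch in range(1, max_epochs + 1):
    val_loss = evaluate(model, val_loader)  # assumed helper
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_wo_improvement = 0
    else:
        epochs_wo_improvement += 1
    if epochs_wo_improvement >= patience:
        print(f'No improvement for {patience} epochs, stopping early.')
        break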
def main():
    global args
    best_prec1, best_epoch = 0.0, 0

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.data.startswith('cifar'):
        IM_SIZE = 32
    else:
        IM_SIZE = 224

    # measure FLOPs on a throwaway instance, then rebuild the model
    model = getattr(models, args.arch)(args)
    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)
    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model

    model = getattr(models, args.arch)(args)
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'radam':
        from radam import RAdam
        optimizer = RAdam(model.parameters(), args.lr,
                          weight_decay=args.weight_decay)
    else:
        raise NotImplementedError('Unsupported optimizer: {}'.format(args.optimizer))

    if args.resume:
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_loader, val_loader, test_loader = get_dataloaders(args)

    if args.evalmode is not None:
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)
        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        else:
            dynamic_evaluate(model, test_loader, val_loader, args)
        return

    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
              '\tval_prec1\ttrain_prec5\tval_prec5']

    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_prec1, train_prec5, lr = train(
            train_loader, model, criterion, optimizer, epoch)
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(('{}\t{:.3f}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_prec1, val_prec1, train_prec5, val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, args, is_best, model_filename, scores)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    ### Test the final model
    print('********** Final prediction results **********')
    validate(test_loader, model, criterion)
    return
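# `load_checkpoint` and `save_checkpoint` are helpers defined elsewhere in
# the project. One plausible implementation of `load_checkpoint`, returning
# the newest `checkpoint_NNN.pth.tar` in `args.save` (a sketch under that
# naming assumption, not the project's actual helper):
import glob

def load_checkpoint(args):
    """Return the most recent checkpoint dict found in args.save, or None."""
    candidates = sorted(glob.glob(os.path.join(args.save, 'checkpoint_*.pth.tar')))
    if not candidates:
        return None
    latest = candidates[-1]
    print('=> loading checkpoint', latest)
    return torch.load(latest, map_location='cpu')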
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with the validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif opt.Prediction == 'None':
        converter = TransformerConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    # model = torch.nn.DataParallel(model).to(device)
    model = model.to(device)
    model.train()
    if opt.load_from_checkpoint:
        model.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'checkpoint.pth')))
        print(f'loaded checkpoint from {opt.load_from_checkpoint}...')
    elif opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.SequenceModeling == 'Transformer':
            # only the feature extractor can be reused for the Transformer variant
            fe_state = OrderedDict()
            state_dict = torch.load(opt.saved_model)
            for k, v in state_dict.items():
                if k.startswith('module.FeatureExtraction'):
                    new_k = re.sub('module.FeatureExtraction.', '', k)
                    fe_state[new_k] = v
            model.FeatureExtraction.load_state_dict(fe_state)
        else:
            if opt.FT:
                model.load_state_dict(torch.load(opt.saved_model), strict=False)
            else:
                model.load_state_dict(torch.load(opt.saved_model))
    if opt.freeze_fe:
        model.freeze(['FeatureExtraction'])
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif opt.Prediction == 'None':
        criterion = LabelSmoothingLoss(classes=converter.n_classes,
                                       padding_idx=converter.pad_idx, smoothing=0.1)
        # criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.pad_idx)
    else:
        # ignore the [GO] token (index 0)
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        assert opt.adam in ['Adam', 'AdamW', 'RAdam'], 'adam optimizer must be one of Adam, AdamW or RAdam'
        if opt.adam == 'Adam':
            optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        elif opt.adam == 'AdamW':
            optimizer = optim.AdamW(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        else:
            optimizer = RAdam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    if opt.load_from_checkpoint and opt.load_optimizer_state:
        optimizer.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'optimizer.pth')))
        print(f'loaded optimizer state from {os.path.join(opt.load_from_checkpoint, "optimizer.pth")}')

    """ final options """
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass
    if opt.load_from_checkpoint:
        with open(os.path.join(opt.load_from_checkpoint, 'iter.json'), mode='r', encoding='utf8') as f:
            start_iter = json.load(f)
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1

    bar = tqdm(range(start_iter, opt.num_iter))
    for i in bar:
        bar.set_description(f'Iter {i}: train_loss = {loss_avg.val():.5f}')

        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0: to avoid a ctc_loss issue, disable
            # cudnn for the computation of the ctc_loss.
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            torch.backends.cudnn.enabled = True

            # (ctc_b) To reproduce our pretrained model / paper, use the previous code
            # (below) instead of (ctc_a). With PyTorch 1.2.0 the code below yields NaN,
            # so you may need PyTorch 1.1.0; the result of CTCLoss differs between
            # PyTorch 1.1.0 and 1.2.0.
            # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)
        elif opt.Prediction == 'None':
            tgt_input = text['tgt_input']
            tgt_output = text['tgt_output']
            tgt_padding_mask = text['tgt_padding_mask']
            preds = model(image, tgt_input.transpose(0, 1), tgt_key_padding_mask=tgt_padding_mask)
            cost = criterion(preds.view(-1, preds.shape[-1]), tgt_output.contiguous().view(-1))
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without the [GO] symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping (default: 5)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = (f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, '
                            f'Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}')
                loss_avg.reset()

                current_model_log = (f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, '
                                     f'{"Current_norm_ED":17s}: {current_norm_ED:0.2f}')

                # keep the best-accuracy model (on the valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')

                # checkpoint
                os.makedirs(f'./checkpoints/{opt.experiment_name}/', exist_ok=True)
                torch.save(model.state_dict(), f'./checkpoints/{opt.experiment_name}/checkpoint.pth')
                torch.save(optimizer.state_dict(), f'./checkpoints/{opt.experiment_name}/optimizer.pth')
                with open(f'./checkpoints/{opt.experiment_name}/iter.json', mode='w', encoding='utf8') as f:
                    json.dump(i + 1, f)
                with open(f'./checkpoints/{opt.experiment_name}/checkpoint.log', mode='a', encoding='utf8') as f:
                    f.write(f'Saved checkpoint with iter={i}\n')
                    f.write(f'\tCheckpoint at: ./checkpoints/{opt.experiment_name}/checkpoint.pth')
                    f.write(f'\tOptimizer at: ./checkpoints/{opt.experiment_name}/optimizer.pth')

                best_model_log = (f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, '
                                  f'{"Best_norm_ED":17s}: {best_norm_ED:0.2f}')
                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save the model every 100000 iterations
        if (i + 1) % 100000 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

    print('end training')
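# `Averager` is not defined in this excerpt; from its usage above (`add`,
# `val`, `reset`) it is a running mean over loss tensors. A minimal sketch
# consistent with that interface (an assumption, not the project's class):
class Averager(object):
    """Running average of scalar loss tensors."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        self.n_count += count
        self.sum += v.data.sum().item()

    def val(self):
        # guard against division by zero before the first add()
        return self.sum / float(self.n_count) if self.n_count != 0 else 0.0

    def reset(self):
        self.n_count = 0
        self.sum = 0.0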
class Trainer(object):
    '''This class takes care of training and validation of our model.'''

    def __init__(self, model):
        self.fold = args.fold
        self.total_folds = 5
        self.num_workers = 6
        self.batch_size = {"train": args.batch_size, "val": args.batch_size}
        # accumulate gradients so that the effective batch size is 32
        self.accumulation_steps = 32 // self.batch_size['train']
        self.lr = args.learning_rate
        self.num_epochs = args.epochs
        self.best_loss = float("inf")
        self.best_dice = 0
        self.phases = ["train", "val"]
        self.device = torch.device("cuda:0")
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
        self.net = model
        self.criterion = MixedLoss(10.0, 2.0)
        if args.swa is True:
            # base_opt = torch.optim.SGD(self.net.parameters(), lr=args.max_lr,
            #                            momentum=args.momentum, weight_decay=args.weight_decay)
            base_opt = RAdam(self.net.parameters(), lr=self.lr)
            self.optimizer = SWA(base_opt, swa_start=38, swa_freq=1, swa_lr=args.min_lr)
        else:
            if args.optimizer.lower() == 'adam':
                self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
            elif args.optimizer.lower() == 'radam':
                # betas=(args.beta1, args.beta2), weight_decay=args.weight_decay
                self.optimizer = RAdam(self.net.parameters(), lr=self.lr)
            elif args.optimizer.lower() == 'sgd':
                self.optimizer = torch.optim.SGD(self.net.parameters(), lr=args.max_lr,
                                                 momentum=args.momentum,
                                                 weight_decay=args.weight_decay)
        if args.scheduler.lower() == 'reducelronplateau':
            self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
                                               patience=args.patience, verbose=True)
        elif args.scheduler.lower() == 'clr':
            self.scheduler = CyclicLR(self.optimizer, base_lr=self.lr, max_lr=args.max_lr)
        self.net = self.net.to(self.device)
        cudnn.benchmark = True
        self.dataloaders = {
            phase: provider(
                fold=args.fold,
                total_folds=5,
                data_folder=data_folder,
                df_path=train_rle_path,
                phase=phase,
                size=args.img_size_target,
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                batch_size=self.batch_size[phase],
                num_workers=self.num_workers,
            )
            for phase in self.phases
        }
        self.losses = {phase: [] for phase in self.phases}
        self.iou_scores = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.kaggle_metric = {phase: [] for phase in self.phases}

    def forward(self, images, targets):
        images = images.to(self.device)
        masks = targets.to(self.device)
        outputs = self.net(images)
        # alternatives: weighted_lovasz, lovasz_hinge(outputs, masks)
        loss = self.criterion(outputs, masks)
        return loss, outputs

    def iterate(self, epoch, phase):
        meter = Meter(phase, epoch)
        start = time.strftime("%H:%M:%S")
        print(f"Starting epoch: {epoch} | phase: {phase} | ⏰: {start}")
        batch_size = self.batch_size[phase]
        start = time.time()
        self.net.train(phase == "train")
        dataloader = self.dataloaders[phase]
        running_loss = 0.0
        total_batches = len(dataloader)
        tk0 = tqdm(dataloader, total=total_batches)
        self.optimizer.zero_grad()
        for itr, batch in enumerate(tk0):
            images, targets = batch
            loss, outputs = self.forward(images, targets)
            loss = loss / self.accumulation_steps
            if phase == "train":
                loss.backward()
                # step only once every accumulation_steps mini-batches
                if (itr + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            running_loss += loss.item()
            outputs = outputs.detach().cpu()
            meter.update(targets, outputs)
            tk0.set_postfix(loss=(running_loss / (itr + 1)))
        if args.swa is True:
            self.optimizer.swap_swa_sgd()
        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
        dice, iou, scores, kaggle_metric = epoch_log(phase, epoch, epoch_loss, meter, start)
        write_event(log, dice, loss=epoch_loss)
        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(dice)
        self.iou_scores[phase].append(iou)
        self.kaggle_metric[phase].append(kaggle_metric)
        torch.cuda.empty_cache()
        return epoch_loss, dice, iou, scores, kaggle_metric

    def start(self):
        log_dir = args.log_path + 'v' + str(args.version) + '/' + str(args.fold)
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)
        os.makedirs(log_dir)
        writer = SummaryWriter(log_dir)
        model_path = args.weights_path + 'v' + str(args.version) + '/' + save_model_name
        if os.path.exists(model_path):
            # resume from the last saved state
            state = torch.load(model_path, map_location=lambda storage, loc: storage)
            self.net.load_state_dict(state["state_dict"])
            epoch = state['epoch']
            self.best_loss = state['best_loss']
            self.best_dice = state['best_dice']
        else:
            epoch = 1
            self.best_loss = float('inf')
            self.best_dice = 0
        scores = 'None'
        for epoch in range(epoch, self.num_epochs + 1):
            print('-' * 30, 'Epoch:', epoch, '-' * 30)
            train_loss, train_dice, train_iou, train_scores, train_kaggle_metric = \
                self.iterate(epoch, "train")
            state = {
                "epoch": epoch,
                "best_loss": self.best_loss,
                "best_dice": self.best_dice,
                "state_dict": self.net.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            try:
                val_loss, val_dice, val_iou, val_scores, val_kaggle_metric = \
                    self.iterate(epoch, "val")
                self.scheduler.step(val_loss)
                if val_loss < self.best_loss:
                    print("******** New optimal found, saving state ********")
                    state["best_loss"] = self.best_loss = val_loss
                    torch.save(state, model_path)
                scores = val_scores
                if val_dice > self.best_dice:
                    print("******** Best Dice Score, saving state ********")
                    state["best_dice"] = self.best_dice = val_dice
                    best_dice_path = (args.weights_path + 'v' + str(args.version)
                                      + '/' + 'best_dice_' + basic_name + '.pth')
                    torch.save(state, best_dice_path)
                writer.add_scalars('loss', {'train': train_loss, 'val': val_loss}, epoch)
                writer.add_scalars('dice_score', {'train': train_dice, 'val': val_dice}, epoch)
                writer.add_scalars('IoU', {'train': train_iou, 'val': val_iou}, epoch)
                writer.add_scalars('New_Dice', {'train': train_kaggle_metric,
                                                'val': val_kaggle_metric}, epoch)
            except KeyboardInterrupt:
                print('Ctrl+C, saving snapshot')
                torch.save(state, args.weights_path + 'v' + str(args.version) + '/' + save_model_name)
                print('done.')
        writer.close()
        return scores
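# For reference: `SWA` above is presumably torchcontrib.optim.SWA. Its usual
# pattern is to swap the averaged weights into the model once, after training
# finishes, and then re-estimate BatchNorm statistics; a hedged sketch of that
# pattern (model, loaders, criterion, and num_epochs are assumed to exist):
from torchcontrib.optim import SWA

base_opt = torch.optim.SGD(model.parameters(), lr=1e-2)
opt = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=5e-3)
for epoch in range(num_epochs):          # assumed training loop
    for images, masks in train_loader:   # assumed dataloader
        opt.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        opt.step()
opt.swap_swa_sgd()                  # copy the averaged weights into the model
opt.bn_update(train_loader, model)  # re-estimate BatchNorm running statistics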
# save a periodic checkpoint
if ((epoch + 1) % check_point == 0) or (epoch == num_epoch - 1) \
        or (epoch + 1 > 90) or (bleu_score > 4):
    model_check_point = '%s/model_trainable_%d.pk' % (save_folder, epoch + 1)
    optim_check_point = '%s/optim_trainable_%d.pkl' % (save_folder, epoch + 1)
    loss_check_point = '%s/loss_trainable_%d.pkl' % (save_folder, epoch + 1)
    epoch_check_point = '%s/epoch_trainable_%d.pkl' % (save_folder, epoch + 1)
    bleu_check_point = '%s/bleu_trainable_%d.pkl' % (save_folder, epoch + 1)
    torch.save(model.state_dict(), model_check_point)
    torch.save(optimizer.state_dict(), optim_check_point)
    torch.save(loss_values, loss_check_point)
    torch.save(epoch_values, epoch_check_point)
    torch.save(bleu_values, bleu_check_point)

# save the current best result
if bleu_score > best_bleu:
    best_bleu = bleu_score
    print('current best bleu: %.4f' % best_bleu)
    model_check_point = '%s/model_best_%d.pk' % (save_folder, epoch + 1)
    optim_check_point = '%s/optim_best_%d.pkl' % (save_folder, epoch + 1)
    torch.save(model.state_dict(), model_check_point)
    torch.save(optimizer.state_dict(), optim_check_point)
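# To resume from these files, load the state dicts back in the same layout;
# a minimal sketch, assuming the same `save_folder` and a known epoch number
# (`epoch_to_load` is illustrative, not part of the original code):
epoch_to_load = 100

model.load_state_dict(torch.load('%s/model_trainable_%d.pk' % (save_folder, epoch_to_load)))
optimizer.load_state_dict(torch.load('%s/optim_trainable_%d.pkl' % (save_folder, epoch_to_load)))
loss_values = torch.load('%s/loss_trainable_%d.pkl' % (save_folder, epoch_to_load))
epoch_values = torch.load('%s/epoch_trainable_%d.pkl' % (save_folder, epoch_to_load))
bleu_values = torch.load('%s/bleu_trainable_%d.pkl' % (save_folder, epoch_to_load))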
class Trainer:
    def __init__(self, model, train_loader, test_loader, epochs=200,
                 batch_size=60, run_id=0, logs_dir='logs', device='cpu',
                 saturation_device=None, optimizer='None', plot=True,
                 compute_top_k=False, data_prallel=False,
                 conv_method='channelwise', thresh=.99, half_precision=False,
                 downsampling=None):
        self.saturation_device = device if saturation_device is None else saturation_device
        self.device = device
        self.model = model
        self.epochs = epochs
        self.plot = plot
        self.compute_top_k = compute_top_k
        if 'cuda' in device:
            cudnn.benchmark = True
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()

        print('Checking for optimizer for {}'.format(optimizer))
        if optimizer == 'adam':
            print('Using Adam')
            self.optimizer = optim.Adam(model.parameters())
        elif optimizer == 'adam_lr':
            print('Using Adam with a higher learning rate')
            self.optimizer = optim.Adam(model.parameters(), lr=0.01)
        elif optimizer == 'adam_lr2':
            print('Using Adam with a lower learning rate')
            self.optimizer = optim.Adam(model.parameters(), lr=0.0001)
        elif optimizer == 'SGD':
            print('Using SGD')
            # lr is a required argument of optim.SGD; 0.1 matches the 'LRS' branch
            self.optimizer = optim.SGD(model.parameters(), lr=0.1,
                                       momentum=0.9, weight_decay=5e-4)
        elif optimizer == 'LRS':
            print('Using SGD with a step learning-rate schedule')
            self.optimizer = optim.SGD(model.parameters(), lr=0.1,
                                       momentum=0.9, weight_decay=5e-4)
            self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, self.epochs // 3)
        elif optimizer == 'radam':
            print('Using RAdam')
            self.optimizer = RAdam(model.parameters())
        else:
            raise ValueError('Unknown optimizer {}'.format(optimizer))
        self.opt_name = optimizer

        save_dir = os.path.join(logs_dir, model.name, train_loader.name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        self.savepath = os.path.join(
            save_dir,
            f'{model.name}_bs{batch_size}_e{epochs}_dspl{downsampling}_t{int(thresh * 1000)}_id{run_id}.csv')

        self.experiment_done = False
        if os.path.exists(self.savepath):
            trained_epochs = len(pd.read_csv(self.savepath, sep=';'))
            if trained_epochs >= epochs:
                self.experiment_done = True
                print('Logs for an identical experiment with the same run_id were found; '
                      'training will be skipped. Consider using another run_id.')

        self.parallel = data_prallel
        if os.path.exists(self.savepath.replace('.csv', '.pt')):
            # resume: restore model and optimizer state from the saved .pt file
            state = torch.load(self.savepath.replace('.csv', '.pt'))
            self.model.load_state_dict(state['model_state_dict'])
            if data_prallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
            if half_precision:
                self.model = self.model.half()
            self.optimizer.load_state_dict(state['optimizer'])
            self.start_epoch = state['epoch'] + 1
            initial_epoch = self._infer_initial_epoch(self.savepath)
            print('Resuming existing run, starting at epoch', self.start_epoch,
                  'from', self.savepath.replace('.csv', '.pt'))
        else:
            if half_precision:
                self.model = self.model.half()
            self.start_epoch = 0
            initial_epoch = 0
            if data_prallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)

        writer = CSVandPlottingWriter(self.savepath.replace('.csv', ''),
                                      fontsize=16, primary_metric='test_accuracy')
        writer2 = NPYWriter(self.savepath.replace('.csv', ''))
        self.pooling_strat = conv_method
        print('Setting saturation recording threshold to', thresh)
        self.half = half_precision
        self.stats = CheckLayerSat(self.savepath.replace('.csv', ''), [writer],
                                   model,
                                   ignore_layer_names='convolution',
                                   stats=['lsat', 'idim'],
                                   sat_threshold=thresh,
                                   verbose=False,
                                   conv_method=conv_method,
                                   log_interval=1,
                                   device=self.saturation_device,
                                   reset_covariance=True,
                                   max_samples=None,
                                   initial_epoch=initial_epoch,
                                   interpolation_strategy='nearest' if downsampling is not None else None,
                                   interpolation_downsampling=4)

    def _infer_initial_epoch(self, savepath):
        if not os.path.exists(savepath):
            return 0
        df = pd.read_csv(savepath, sep=';', index_col=0)
        return len(df)

    def train(self):
        if self.experiment_done:
            return
        for epoch in range(self.start_epoch, self.epochs):
            print('Start training epoch', epoch)
            print('{} Epoch {}, training loss: {}, training accuracy: {}'.format(
                now(), epoch, *self.train_epoch()))
            self.test(epoch=epoch)
            if self.opt_name == 'LRS':
                print('LRS step')
                self.lr_scheduler.step()
            self.stats.add_saturations()
        self.stats.close()
        return self.savepath

    def train_epoch(self):
        self.model.train()
        correct = 0
        total = 0
        running_loss = 0
        old_time = time()
        top5_accumulator = 0
        for batch, data in enumerate(self.train_loader):
            if batch % 10 == 0 and batch != 0:
                print(batch, 'of', len(self.train_loader),
                      'processing time', time() - old_time,
                      'top5_acc:' if self.compute_top_k else 'acc:',
                      round(top5_accumulator / batch, 3) if self.compute_top_k else correct / total)
                old_time = time()
            inputs, labels = data
            if self.half:
                inputs, labels = inputs.to(self.device).half(), labels.to(self.device)
            else:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            if self.compute_top_k:
                top5_accumulator += accuracy(outputs, labels, (5,))[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            correct += (predicted == labels.long()).sum().item()
            running_loss += loss.item()
        self.stats.add_scalar('training_loss', running_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('training_accuracy', top5_accumulator / (batch + 1))
        else:
            self.stats.add_scalar('training_accuracy', correct / total)
        return running_loss / total, correct / total

    def test(self, epoch, save=True):
        self.model.eval()
        correct = 0
        total = 0
        test_loss = 0
        top5_accumulator = 0
        with torch.no_grad():
            for batch, data in enumerate(self.test_loader):
                if batch % 10 == 0:
                    print('Processing eval batch', batch, 'of', len(self.test_loader))
                inputs, labels = data
                if self.half:
                    inputs, labels = inputs.to(self.device).half(), labels.to(self.device)
                else:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.long()).sum().item()
                if self.compute_top_k:
                    top5_accumulator += accuracy(outputs, labels, (5,))[0]
                test_loss += loss.item()
        self.stats.add_scalar('test_loss', test_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('test_accuracy', top5_accumulator / (batch + 1))
            print('{} Test Top5-Accuracy on {} images: {:.4f}'.format(
                now(), total, top5_accumulator / (batch + 1)))
        else:
            self.stats.add_scalar('test_accuracy', correct / total)
            print('{} Test Accuracy on {} images: {:.4f}'.format(now(), total, correct / total))
        if save:
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'epoch': epoch,
                'test_loss': test_loss / total
            }, self.savepath.replace('.csv', '.pt'))
        return correct / total, test_loss / total
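# A typical invocation of this Trainer; `build_model`, `train_loader`, and
# `test_loader` are illustrative names assumed to be provided elsewhere. Note
# that the model and the train loader must expose a `.name` attribute, which
# the Trainer uses to build its log path:
model = build_model()
trainer = Trainer(model, train_loader, test_loader,
                  epochs=100, batch_size=60, device='cuda:0',
                  optimizer='radam', thresh=0.99)
csv_path = trainer.train()  # runs train/test each epoch and logs saturation stats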