def main(args):
    # initial setup
    if args.checkpoint == '':
        args.checkpoint = "checkpoints1/ic19val_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0
    validation_split = 0.1
    random_seed = 42
    prev_val_loss = -1
    val_loss_list = []
    loggertf = tfLogger('./log/' + args.arch)
    # end

    # set up data loaders: shuffle the dataset indices once, then split them
    # into train/validation subsets served by SubsetRandomSampler
    data_loader = IC19Loader(is_transform=True,
                             img_size=args.img_size,
                             kernel_num=kernel_num,
                             min_scale=min_scale)
    dataset_size = len(data_loader)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    np.random.seed(random_seed)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    validate_sampler = SubsetRandomSampler(val_indices)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True,
                                               sampler=train_sampler)
    validate_loader = torch.utils.data.DataLoader(data_loader,
                                                  batch_size=args.batch_size,
                                                  num_workers=3,
                                                  drop_last=True,
                                                  pin_memory=True,
                                                  sampler=validate_sampler)
    # end

    # set up architecture and optimizer
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet50":
        model = models.resPAnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet101":
        model = models.resPAnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet152":
        model = models.resPAnet152(pretrained=True, num_classes=kernel_num)

    model = torch.nn.DataParallel(model).cuda()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    # end

    # options to resume, use a pretrained model, or train from scratch
    title = 'icdar2019MLT'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.',
            'Validate Loss', 'Validate Acc', 'Validate IOU'
        ])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.',
            'Validate Loss', 'Validate Acc', 'Validate IOU'
        ])
    # end

    # start training the model
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch, loggertf)
        val_loss, val_te_acc, val_ke_acc, val_te_iou, val_ke_iou = validate(
            validate_loader, model, dice_loss)

        # logging on TensorBoard
        loggertf.scalar_summary('Training/Accuracy', train_te_acc, epoch + 1)
        loggertf.scalar_summary('Training/Loss', train_loss, epoch + 1)
        loggertf.scalar_summary('Training/IoU', train_te_iou, epoch + 1)
        loggertf.scalar_summary('Validation/Accuracy', val_te_acc, epoch + 1)
        loggertf.scalar_summary('Validation/Loss', val_loss, epoch + 1)
        loggertf.scalar_summary('Validation/IoU', val_te_iou, epoch + 1)
        # end

        # bookkeeping
        print("End of Epoch %d" % (epoch + 1))
        print("Train Loss: {loss:.4f} | Train Acc: {acc: .4f} | Train IOU: {iou_t: .4f}"
              .format(loss=train_loss, acc=train_te_acc, iou_t=train_te_iou))
        print("Validation Loss: {loss:.4f} | Validation Acc: {acc: .4f} | Validation IOU: {iou_t: .4f}"
              .format(loss=val_loss, acc=val_te_acc, iou_t=val_te_iou))
        # end

        # save improving and best models
        val_loss_list.append(val_loss)
        if val_loss < prev_val_loss or prev_val_loss == -1:
            checkpointname = "{loss:.3f}".format(
                loss=val_loss) + "_epoch" + str(epoch + 1) + "_checkpoint.pth.tar"
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer': optimizer.state_dict(),
                },
                checkpoint=args.checkpoint,
                filename=checkpointname)
        # the list already contains the current loss, so "best so far" means
        # it equals the minimum of the list
        if val_loss <= min(val_loss_list):
            checkpointname = "best_checkpoint.pth.tar"
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer': optimizer.state_dict(),
                },
                checkpoint=args.checkpoint,
                filename=checkpointname)
        # end
        prev_val_loss = val_loss

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou, val_loss, val_te_acc, val_te_iou
        ])
    # end training the model

    logger.close()
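# --- Illustrative sketch, not part of the original training script ---
# The IC19 main() above builds its train/validation split by shuffling the
# dataset indices once and wrapping each half in a SubsetRandomSampler. The
# helper below shows that pattern in isolation; the dummy TensorDataset,
# batch size, and function name are placeholders, while the 10% split and
# seed 42 mirror the values used above.
def _split_demo():
    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.sampler import SubsetRandomSampler

    dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))  # placeholder data
    indices = list(range(len(dataset)))
    split = int(np.floor(0.1 * len(dataset)))  # hold out 10% for validation
    np.random.seed(42)
    np.random.shuffle(indices)
    train_idx, val_idx = indices[split:], indices[:split]
    train_loader = DataLoader(dataset, batch_size=8,
                              sampler=SubsetRandomSampler(train_idx))
    val_loader = DataLoader(dataset, batch_size=8,
                            sampler=SubsetRandomSampler(val_idx))
    return train_loader, val_loader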
def main(args):
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ctw1500_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    data_loader = CTW1500Loader(is_transform=True,
                                img_size=args.img_size,
                                kernel_num=kernel_num,
                                min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)

    model = torch.nn.DataParallel(model).cuda()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    title = 'CTW1500'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])

    logger.close()
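# --- Illustrative sketch, not part of the original scripts ---
# Both main() functions above (the two listings appear to come from separate
# training scripts; in a single file the second definition would shadow the
# first) expect an argparse-style namespace with at least arch, img_size,
# n_epoch, schedule, batch_size, lr, resume, pretrain, and checkpoint. A
# minimal entry point along those lines might look like the following; the
# default values and help text are assumptions, not taken from the original
# repository.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Training hyperparameters (assumed defaults)')
    parser.add_argument('--arch', default='resnet50', type=str)
    parser.add_argument('--img_size', default=640, type=int)
    parser.add_argument('--n_epoch', default=600, type=int)
    parser.add_argument('--schedule', default=[200, 400], type=int, nargs='+')
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--pretrain', default='', type=str)
    parser.add_argument('--checkpoint', default='', type=str)
    args = parser.parse_args()

    main(args)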