def __init__(self, network, w_lr=0.01, w_mom=0.9, w_wd=1e-4,
             t_lr=0.001, t_wd=3e-3, t_beta=(0.5, 0.999),
             init_temperature=5.0, temperature_decay=0.965,
             logger=logging, lr_scheduler=None, gpus=None,
             save_theta_prefix='', theta_result_path='./theta-result',
             checkpoints_path='./checkpoints'):
    """Set up an FBNet differentiable-NAS trainer.

    Wraps the supervised weights in an SGD optimizer (with cosine decay)
    and the architecture parameters (theta) in an Adam optimizer, and
    prepares Gumbel-softmax temperature annealing state.

    Args:
        network: an ``FBNet`` instance (asserted).
        w_lr, w_mom, w_wd: SGD hyper-parameters for the model weights.
        t_lr, t_wd, t_beta: Adam hyper-parameters for theta.
        init_temperature: starting Gumbel-softmax temperature.
        temperature_decay: multiplicative decay applied to the temperature.
        logger: logging object (defaults to the ``logging`` module).
        lr_scheduler: kwargs for ``CosineDecayLR``; defaults to
            ``{'T_max': 200}``.
        gpus: list of device ids, or a comma-separated string like ``"0,1"``;
            defaults to ``[0]``.
        save_theta_prefix: filename prefix used when saving theta snapshots.
        theta_result_path: directory for theta results (created if missing).
        checkpoints_path: directory for checkpoints (created if missing).
    """
    # BUG FIX: the originals were mutable default arguments
    # (lr_scheduler={'T_max': 200}, gpus=[0]) shared across every call;
    # use None sentinels and build fresh objects per instance instead.
    if lr_scheduler is None:
        lr_scheduler = {'T_max': 200}
    if gpus is None:
        gpus = [0]
    assert isinstance(network, FBNet)
    network.apply(weights_init)
    network = network.train().cuda()
    if isinstance(gpus, str):
        gpus = [int(i) for i in gpus.strip().split(',')]
    network = DataParallel(network, gpus)
    self.gpus = gpus
    self._mod = network
    # NOTE(review): ``.theta`` is read off the DataParallel wrapper, so this
    # relies on ``DataParallel`` here being a project class that forwards
    # attribute access to the wrapped module — confirm against its source.
    theta_params = network.theta
    # NOTE(review): ``parameters()`` on the wrapper may also yield theta;
    # if so, theta is updated by both optimizers — verify intended.
    mod_params = network.parameters()
    self.theta = theta_params
    self.w = mod_params
    self._tem_decay = temperature_decay
    self.temp = init_temperature
    self.logger = logger
    self.save_theta_prefix = save_theta_prefix
    # exist_ok=True avoids the check-then-create race of the original
    # ``if not os.path.exists(...): os.makedirs(...)`` pattern.
    os.makedirs(theta_result_path, exist_ok=True)
    self.theta_result_path = theta_result_path
    os.makedirs(checkpoints_path, exist_ok=True)
    self.checkpoints_path = checkpoints_path
    # Running averages reported during search.
    self._acc_avg = AvgrageMeter('acc')
    self._ce_avg = AvgrageMeter('ce')
    self._lat_avg = AvgrageMeter('lat')
    self._loss_avg = AvgrageMeter('loss')
    # SGD + cosine decay for the model weights.
    self.w_opt = torch.optim.SGD(mod_params, w_lr,
                                 momentum=w_mom, weight_decay=w_wd)
    self.w_sche = CosineDecayLR(self.w_opt, **lr_scheduler)
    # Adam for the architecture parameters.
    self.t_opt = torch.optim.Adam(theta_params, lr=t_lr,
                                  betas=t_beta, weight_decay=t_wd)
def main(args):
    """Run the full detection training loop.

    Builds the data loader and ResNet-based model from ``args``, optionally
    restores pretrained/resumed weights, trains for ``args.n_epoch`` epochs,
    periodically validates, and writes ``best``/``latest`` checkpoints plus a
    log file under ``args.checkpoint``.

    Raises:
        ValueError: if ``args.arch`` is not one of resnet50/101/152.
    """
    title = args.title
    if args.checkpoint == '':
        # Auto-generate a checkpoint directory name from the run settings.
        args.checkpoint = "checkpoints/%s_%s_bs_%d_ep_%d" % (
            title, args.arch, args.batch_size, args.n_epoch)
        if args.pretrain:
            if 'synth' in args.pretrain:
                args.checkpoint += "_pretrain_synth"
            else:
                args.checkpoint += "_pretrain_ic17"

    print(('checkpoint path: %s' % args.checkpoint))
    print(('init lr: %.8f' % args.lr))
    print(('schedule: ', args.schedule))
    # ``vals`` is a ';'-separated list of validation set names (may be empty).
    args.vals = args.vals.split(';') if args.vals else []
    print('vals:', args.vals)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7    # number of shrunk text kernels predicted per image
    min_scale = 0.4   # smallest kernel shrink ratio
    start_epoch = 0

    data_loader = OcrDataLoader(args, is_transform=True,
                                img_size=args.img_size,
                                kernel_num=kernel_num,
                                min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    else:
        # BUG FIX: the original fell through with ``model`` unbound and
        # crashed later with an unrelated NameError; fail fast instead.
        raise ValueError('unsupported architecture: %s' % args.arch)

    if len(args.gpus) > 1:
        model = DataParallel(model, device_ids=args.gpus,
                             chunk_sizes=args.chunk_sizes).cuda()
        optimizer = model.module.optimizer
    else:
        model = model.cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    momentum=0.99, weight_decay=5e-4)

    if args.pretrain:
        # Load weights only; training restarts from epoch 0.
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        # Restore weights, optimizer state, and the epoch counter.
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title, resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    best_target = {'epoch': 0, 'val': 0}
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print(('\nEpoch: [%d | %d] LR: %f' % (
            epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr'])))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)

        # Validate and keep the best-scoring checkpoint.
        if args.vals:
            target = run_tests(args, model, epoch)
            if target > best_target['val']:
                best_target['val'] = target
                best_target['epoch'] = epoch + 1
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'lr': args.lr,
                        'optimizer': optimizer.state_dict(),
                    },
                    checkpoint=args.checkpoint,
                    filename='best.pth.tar')
            print('best_target: epoch: %d, val:%.4f' % (
                best_target['epoch'], best_target['val']))

        # Always save the latest state so training can resume.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()