import os
import sys

import torch

# Repo-local modules; the exact module paths are assumptions inferred from the names used below.
import models
from dataset import IC15Loader
from util import Logger

# Note: dice_loss, train, and save_checkpoint are expected to be defined
# elsewhere in this training script.


def main(args):
    # Derive a default checkpoint directory name from the training configuration.
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ic15_%s_bs_%d_ep_%d" % (args.arch, args.batch_size, args.n_epoch)
        if args.pretrain:
            if 'synth' in args.pretrain:
                args.checkpoint += "_pretrain_synth"
            else:
                args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7   # number of kernel maps predicted per image
    min_scale = 0.4  # smallest kernel shrink ratio
    start_epoch = 0

    data_loader = IC15Loader(is_transform=True, img_size=args.img_size,
                             kernel_num=kernel_num, min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=3,
        drop_last=True,
        pin_memory=True)

    # Backbone selection; the head outputs one segmentation map per kernel.
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    else:
        raise ValueError('unsupported architecture: %s' % args.arch)

    model = torch.nn.DataParallel(model).cuda()

    # Prefer a model-defined optimizer if one exists; otherwise fall back to SGD.
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    momentum=0.99, weight_decay=5e-4)

    title = 'icdar2015'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(args.pretrain), 'Error: no pretrained weights found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(args.resume), 'Error: no checkpoint found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)

        # Save every epoch so training can be resumed via --resume.
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'lr': args.lr,
            'optimizer': optimizer.state_dict(),
        }, checkpoint=args.checkpoint)

        logger.append([optimizer.param_groups[0]['lr'], train_loss, train_te_acc, train_te_iou])

    logger.close()
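
# main() calls adjust_learning_rate(args, optimizer, epoch), but that function is
# not shown above. A minimal sketch, assuming step decay: the learning rate is
# divided by 10 each time the epoch passes a milestone in args.schedule.
def adjust_learning_rate(args, optimizer, epoch):
    lr = args.lr
    for milestone in args.schedule:
        if epoch >= milestone:
            lr *= 0.1  # assumed decay factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr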
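
# A sketch of the command-line entry point main() expects. The flag names mirror
# the attributes read above; the default values are illustrative assumptions.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='ICDAR 2015 text detection training')
    parser.add_argument('--arch', type=str, default='resnet50',
                        choices=['resnet50', 'resnet101', 'resnet152'])
    parser.add_argument('--img_size', type=int, default=640)
    parser.add_argument('--n_epoch', type=int, default=600)
    parser.add_argument('--schedule', type=int, nargs='+', default=[200, 400],
                        help='epochs at which to decay the learning rate')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--resume', type=str, default='',
                        help='path of a checkpoint to resume from')
    parser.add_argument('--pretrain', type=str, default='',
                        help='path of pretrained weights (synth or ic17)')
    parser.add_argument('--checkpoint', type=str, default='',
                        help='checkpoint directory (auto-named when empty)')

    main(parser.parse_args())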