def setup(self):
    """Initialize everything needed for training: output dir, logger, device,
    datasets/dataloaders, model, optimizer, losses, and (optionally) a resumed
    checkpoint.

    Reads all configuration from ``self.args`` (argparse-style namespace).
    Side effects: creates ``ckpts/<sub_dir>`` on disk, opens a log file,
    and binds many attributes on ``self`` (model, optimizer, losses, ...).

    Raises:
        Exception: if no CUDA device is available (CPU training unsupported).
        NotImplementedError: if ``args.dataset`` is not one of the known names.
    """
    args = self.args
    # Checkpoint sub-directory name encodes the key hyperparameters so runs
    # with different settings never overwrite each other.
    sub_dir = 'input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
        args.crop_size, args.wot, args.wtv, args.reg,
        args.num_of_iter_in_ot, args.norm_cood)
    self.save_dir = os.path.join('ckpts', sub_dir)
    if not os.path.exists(self.save_dir):
        os.makedirs(self.save_dir)

    # Timestamped log file inside the checkpoint dir; config is echoed into it.
    time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
    self.logger = log_utils.get_logger(
        os.path.join(self.save_dir, 'train-{:s}.log'.format(time_str)))
    log_utils.print_config(vars(args), self.logger)

    if torch.cuda.is_available():
        self.device = torch.device("cuda")
        self.device_count = torch.cuda.device_count()
        # Single-GPU training only; multi-GPU is deliberately unsupported here.
        # NOTE(review): `assert` is stripped under `python -O` — a raise would
        # be more robust, left unchanged to preserve behavior.
        assert self.device_count == 1
        self.logger.info('using {} gpus'.format(self.device_count))
    else:
        raise Exception("gpu is not available")

    # Density maps are predicted at 1/8 input resolution for every dataset.
    downsample_ratio = 8
    if args.dataset.lower() == 'qnrf':
        # QNRF/NWPU layouts keep 'train'/'val' subfolders under data_dir.
        self.datasets = {
            x: Crowd_qnrf(os.path.join(args.data_dir, x),
                          args.crop_size, downsample_ratio, x)
            for x in ['train', 'val']
        }
    elif args.dataset.lower() == 'nwpu':
        self.datasets = {
            x: Crowd_nwpu(os.path.join(args.data_dir, x),
                          args.crop_size, downsample_ratio, x)
            for x in ['train', 'val']
        }
    elif args.dataset.lower() == 'sha' or args.dataset.lower() == 'shb':
        # ShanghaiTech ships 'train_data'/'test_data' folders; its test split
        # is used as the validation set here.
        self.datasets = {
            'train': Crowd_sh(os.path.join(args.data_dir, 'train_data'),
                              args.crop_size, downsample_ratio, 'train'),
            'val': Crowd_sh(os.path.join(args.data_dir, 'test_data'),
                            args.crop_size, downsample_ratio, 'val'),
        }
    else:
        raise NotImplementedError

    # Train loader: batched, shuffled, custom collate (crops + point lists).
    # Val loader: batch size 1 with the default collate (full-size images).
    self.dataloaders = {
        x: DataLoader(self.datasets[x],
                      collate_fn=(train_collate
                                  if x == 'train' else default_collate),
                      batch_size=(args.batch_size if x == 'train' else 1),
                      shuffle=(True if x == 'train' else False),
                      num_workers=args.num_workers * self.device_count,
                      pin_memory=(True if x == 'train' else False))
        for x in ['train', 'val']
    }

    #self.model = vgg19()
    self.model = TR_CC()
    # Model must be on the device BEFORE the optimizer captures its parameters
    # and before any checkpoint state is loaded below.
    self.model.to(self.device)
    self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr,
                                weight_decay=args.weight_decay)

    self.start_epoch = 0
    if args.resume:
        self.logger.info('loading pretrained model from ' + args.resume)
        suf = args.resume.rsplit('.', 1)[-1]
        if suf == 'tar':
            # Full training checkpoint: restore model + optimizer and continue
            # from the epoch after the saved one.
            checkpoint = torch.load(args.resume, self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict'])
            self.start_epoch = checkpoint['epoch'] + 1
        elif suf == 'pth':
            # Weights-only file: restore model parameters, optimizer starts fresh.
            self.model.load_state_dict(torch.load(args.resume, self.device))
    else:
        self.logger.info('random initialization')

    # Loss components: optimal-transport loss, per-pixel TV (L1) loss, and
    # MSE/MAE used for the count metrics.
    self.ot_loss = OT_Loss(args.crop_size, downsample_ratio, args.norm_cood,
                           self.device, args.num_of_iter_in_ot, args.reg)
    self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
    self.mse = nn.MSELoss().to(self.device)
    self.mae = nn.L1Loss().to(self.device)
    # Keep at most one rolling checkpoint on disk.
    self.save_list = Save_Handle(max_num=1)
    # Best-so-far validation metrics; updated during training.
    self.best_mae = np.inf
    self.best_mse = np.inf
    self.best_count = 0
def setup(self):
    """Initialize training state from dict-based config (``self.train_args`` /
    ``self.datargs``): output dirs, TensorBoard writer, device, datasets,
    dataloaders, model, optimizer, losses, and optional checkpoint resume.

    Unlike the argparse-based variant of ``setup`` elsewhere in this file,
    ``self.logger`` here is a TensorBoard ``SummaryWriter``, not a text logger.

    Raises:
        Exception: if no CUDA device is available.
        NotImplementedError: if the configured dataset name is unknown.
    """
    train_args = self.train_args
    datargs = self.datargs
    # Run sub-directory encodes the key hyperparameters so distinct
    # configurations never collide on disk.
    sub_dir = 'input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
        train_args['crop_size'], train_args['wot'], train_args['wtv'],
        train_args['reg'], train_args['num_of_iter_in_ot'],
        train_args['norm_cood'])
    time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
    # Checkpoints: <out_path>/ckpts/<conf>/<dataset>/<hparams>/<timestamp>
    self.save_dir = os.path.join(train_args['out_path'], 'ckpts',
                                 train_args['conf_name'],
                                 train_args['dataset'], sub_dir, time_str)
    if not os.path.exists(self.save_dir):
        os.makedirs(self.save_dir)
    # TensorBoard events: <out_path>/runs/<dataset>/<conf>/<timestamp>
    # NOTE(review): ckpts path orders conf/dataset, runs path orders
    # dataset/conf — presumably intentional, but worth confirming.
    log_dir = os.path.join(train_args['out_path'], 'runs',
                           train_args['dataset'], train_args['conf_name'],
                           time_str)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # TODO: Verify args
    self.logger = SummaryWriter(log_dir)

    if torch.cuda.is_available():
        self.device = torch.device("cuda")
        self.device_count = torch.cuda.device_count()
        # Single-GPU only. NOTE(review): `assert` is stripped under
        # `python -O`; a raise would be safer, left as-is to preserve behavior.
        assert self.device_count == 1
    else:
        raise Exception("Gpu is not available")

    # Lazy, dataset-specific import: only the selected Crowd class is loaded.
    dataset_name = train_args['dataset'].lower()
    if dataset_name == 'qnrf':
        from datasets.crowd import Crowd_qnrf as Crowd
    elif dataset_name == 'nwpu':
        from datasets.crowd import Crowd_nwpu as Crowd
    elif dataset_name == 'sha' or dataset_name == 'shb':
        from datasets.crowd import Crowd_sh as Crowd
    elif dataset_name[:3] == 'ucf':
        # Any 'ucf*' name maps to the UCF loader.
        from datasets.crowd import Crowd_ucf as Crowd
    else:
        raise NotImplementedError

    if dataset_name == 'sha' or dataset_name == 'shb':
        # ShanghaiTech has no official val split: carve a fixed 20-image
        # validation subset out of the training data with a fixed seed (42)
        # so the split is reproducible across runs.
        downsample_ratio = train_args['downsample_ratio']
        train_val = Crowd(os.path.join(datargs['data_path'],
                                       datargs["train_path"]),
                          crop_size=train_args['crop_size'],
                          downsample_ratio=downsample_ratio,
                          method='train')
        if dataset_name == 'sha':
            # 280/20 split — assumes SHA train set has exactly 300 images;
            # random_split raises if the total does not match.
            train_set, val = random_split(
                train_val, [280, 20],
                generator=torch.Generator().manual_seed(42))
            # ValSubset presumably re-wraps the subset with val-mode
            # transforms — confirm against its definition.
            val_set = ValSubset(val)
        else:
            # 380/20 split — assumes SHB train set has exactly 400 images.
            train_set, val = random_split(
                train_val, [380, 20],
                generator=torch.Generator().manual_seed(42))
            val_set = ValSubset(val)
        self.datasets = {'train': train_set, 'val': val_set}
    else:
        # Datasets with a real val split: build train and val loaders from
        # their configured directories.
        downsample_ratio = train_args['downsample_ratio']
        self.datasets = {
            'train': Crowd(os.path.join(datargs['data_path'],
                                        datargs["train_path"]),
                           crop_size=train_args['crop_size'],
                           downsample_ratio=downsample_ratio,
                           method='train'),
            'val': Crowd(os.path.join(datargs['data_path'],
                                      datargs["val_path"]),
                         crop_size=train_args['crop_size'],
                         downsample_ratio=downsample_ratio,
                         method='val')
        }

    # Train loader: batched, shuffled, custom collate (crops + point lists).
    # Val loader: batch size 1 with the default collate (full-size images).
    self.dataloaders = {
        x: DataLoader(
            self.datasets[x],
            collate_fn=(train_collate if x == 'train' else default_collate),
            batch_size=(train_args['batch_size'] if x == 'train' else 1),
            shuffle=(True if x == 'train' else False),
            num_workers=train_args['num_workers'] * self.device_count,
            pin_memory=(True if x == 'train' else False))
        for x in ['train', 'val']
    }

    self.model = vgg16dres(map_location=self.device)
    # Model must be on the device BEFORE the optimizer captures its parameters
    # and before checkpoint state is loaded below.
    self.model.to(self.device)
    # for p in self.model.features.parameters():
    #     p.requires_grad = True
    self.optimizer = optim.Adam(self.model.parameters(),
                                lr=train_args['lr'],
                                weight_decay=train_args['weight_decay'],
                                amsgrad=False)
    # for _, p in zip(range(10000), next(self.model.children()).children()):
    #     p.requires_grad = False
    #     print("freeze: ", p)
    # print(self.optimizer.param_groups[0])
    self.start_epoch = 0

    # Loss components; this OT_Loss variant also receives the SummaryWriter
    # (unlike the other setup in this file), presumably for internal logging.
    self.ot_loss = OT_Loss(train_args['crop_size'], downsample_ratio,
                           train_args['norm_cood'], self.device, self.logger,
                           train_args['num_of_iter_in_ot'],
                           train_args['reg'])
    self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
    self.mse = nn.MSELoss().to(self.device)
    self.mae = nn.L1Loss().to(self.device)
    # Keep at most one rolling checkpoint on disk.
    self.save_list = Save_Handle(max_num=1)
    # Best-so-far validation metrics; may be overwritten by a resumed
    # checkpoint just below.
    self.best_mae = np.inf
    self.best_mse = np.inf
    self.best_count = 0

    if train_args['resume']:
        self.logger.add_text(
            'log/train',
            'loading pretrained model from ' + train_args['resume'], 0)
        suf = train_args['resume'].rsplit('.', 1)[-1]
        if suf == 'tar':
            # Full training checkpoint: restore model, optimizer, epoch
            # counter, and the best-metric bookkeeping.
            checkpoint = torch.load(train_args['resume'], self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict'])
            self.start_epoch = checkpoint['epoch'] + 1
            self.best_count = checkpoint['best_count']
            self.best_mae = checkpoint['best_mae']
            self.best_mse = checkpoint['best_mse']
            print(self.best_mae, self.best_mse, self.best_count)
        elif suf == 'pth':
            # Weights-only file: restore parameters, optimizer starts fresh.
            self.model.load_state_dict(
                torch.load(train_args['resume'], self.device))
    else:
        self.logger.add_text('log/train', 'random initialization', 0)

    # Record hyperparameters + dataset sizes to TensorBoard. The metric dict
    # holds initial sentinel values.
    # NOTE(review): on resume these sentinels (inf/inf/0) are logged even
    # though the real best values were just restored above — confirm whether
    # that is intended.
    img_cnts = {
        'val_image_count': len(self.dataloaders['val']),
        'train_image_count': len(self.dataloaders['train'])
    }
    self.logger.add_hparams({
        **self.train_args,
        **img_cnts
    }, {
        'best_mse': np.inf,
        'best_mae': np.inf,
        'best_count': 0
    }, run_name='hparams')