Example #1
    def setup(self):
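        """Set up output dirs, logging, data, model, optimizer and losses."""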
        args = self.args
        sub_dir = 'input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
            args.crop_size, args.wot, args.wtv, args.reg,
            args.num_of_iter_in_ot, args.norm_cood)

        self.save_dir = os.path.join('ckpts', sub_dir)
        os.makedirs(self.save_dir, exist_ok=True)

        time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
        self.logger = log_utils.get_logger(
            os.path.join(self.save_dir, 'train-{:s}.log'.format(time_str)))
        log_utils.print_config(vars(args), self.logger)

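        # Single-GPU training only: fail fast when CUDA is unavailable.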
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            assert self.device_count == 1
            self.logger.info('using {} gpus'.format(self.device_count))
        else:
            raise RuntimeError("GPU is not available")

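        # Every dataset wrapper receives the same crop size and a fixed
        # output stride (downsample_ratio) of 8.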
        downsample_ratio = 8
        if args.dataset.lower() == 'qnrf':
            self.datasets = {
                x: Crowd_qnrf(os.path.join(args.data_dir, x), args.crop_size,
                              downsample_ratio, x)
                for x in ['train', 'val']
            }
        elif args.dataset.lower() == 'nwpu':
            self.datasets = {
                x: Crowd_nwpu(os.path.join(args.data_dir, x), args.crop_size,
                              downsample_ratio, x)
                for x in ['train', 'val']
            }
        elif args.dataset.lower() in ('sha', 'shb'):
            self.datasets = {
                'train':
                Crowd_sh(os.path.join(args.data_dir, 'train_data'),
                         args.crop_size, downsample_ratio, 'train'),
                'val':
                Crowd_sh(os.path.join(args.data_dir, 'test_data'),
                         args.crop_size, downsample_ratio, 'val'),
            }
        else:
            raise NotImplementedError

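        # Training needs a custom collate_fn, since each crowd image carries
        # a variable-length list of point annotations; validation evaluates
        # whole images one at a time.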
        self.dataloaders = {
            x: DataLoader(self.datasets[x],
                          collate_fn=(train_collate
                                      if x == 'train' else default_collate),
                          batch_size=(args.batch_size if x == 'train' else 1),
                          shuffle=(x == 'train'),
                          num_workers=args.num_workers * self.device_count,
                          pin_memory=(x == 'train'))
            for x in ['train', 'val']
        }
        # self.model = vgg19()
        self.model = TR_CC()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)

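        # Optional resume: a .tar checkpoint restores model, optimizer and
        # epoch, while a .pth file holds model weights only.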
        self.start_epoch = 0
        if args.resume:
            self.logger.info('loading pretrained model from ' + args.resume)
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume, map_location=self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(
                    torch.load(args.resume, map_location=self.device))
        else:
            self.logger.info('random initialization')

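        # Objective: an OT loss over the annotation points (Sinkhorn-style,
        # given the reg and iteration-count arguments), a per-pixel L1 ('TV')
        # term, and MSE/MAE for tracking validation error.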
        self.ot_loss = OT_Loss(args.crop_size, downsample_ratio,
                               args.norm_cood, self.device,
                               args.num_of_iter_in_ot, args.reg)
        self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
        self.mse = nn.MSELoss().to(self.device)
        self.mae = nn.L1Loss().to(self.device)
        self.save_list = Save_Handle(max_num=1)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0
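
Both setups pass a custom train_collate to the training DataLoader, but its definition is not shown. A minimal sketch of what such a collate function typically looks like for point-annotated crowd data (the per-item layout is an assumption, not the project's actual code):

import torch

def train_collate(batch):
    # Sketch: assume each dataset item is (image, points, ...) where image is
    # a (C, H, W) tensor and points is an (N_i, 2) tensor of head coordinates.
    images = torch.stack([item[0] for item in batch], dim=0)  # (B, C, H, W)
    points = [item[1] for item in batch]  # ragged: one tensor per image
    return images, points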

Example #2
    def setup(self):
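        """Variant of Example #1 driven by config dicts, logging to TensorBoard."""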
        train_args = self.train_args
        datargs = self.datargs
        sub_dir = 'input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
            train_args['crop_size'], train_args['wot'], train_args['wtv'],
            train_args['reg'], train_args['num_of_iter_in_ot'],
            train_args['norm_cood'])

        time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
        self.save_dir = os.path.join(train_args['out_path'], 'ckpts',
                                     train_args['conf_name'],
                                     train_args['dataset'], sub_dir, time_str)
        os.makedirs(self.save_dir, exist_ok=True)
        log_dir = os.path.join(train_args['out_path'], 'runs',
                               train_args['dataset'], train_args['conf_name'],
                               time_str)
        os.makedirs(log_dir, exist_ok=True)

        # TODO: Verify args
        self.logger = SummaryWriter(log_dir)
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            assert self.device_count == 1
        else:
            raise RuntimeError("GPU is not available")

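        # Import only the dataset wrapper that the config actually selects.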
        dataset_name = train_args['dataset'].lower()
        if dataset_name == 'qnrf':
            from datasets.crowd import Crowd_qnrf as Crowd
        elif dataset_name == 'nwpu':
            from datasets.crowd import Crowd_nwpu as Crowd
        elif dataset_name in ('sha', 'shb'):
            from datasets.crowd import Crowd_sh as Crowd
        elif dataset_name.startswith('ucf'):
            from datasets.crowd import Crowd_ucf as Crowd
        else:
            raise NotImplementedError
        if dataset_name in ('sha', 'shb'):
            downsample_ratio = train_args['downsample_ratio']
            train_val = Crowd(os.path.join(datargs['data_path'],
                                           datargs["train_path"]),
                              crop_size=train_args['crop_size'],
                              downsample_ratio=downsample_ratio,
                              method='train')
            # Hold out 20 of the training images as a fixed-seed validation
            # split (SHA has 300 train images, SHB has 400).
            n_train = 280 if dataset_name == 'sha' else 380
            train_set, val = random_split(
                train_val, [n_train, 20],
                generator=torch.Generator().manual_seed(42))
            val_set = ValSubset(val)
            self.datasets = {'train': train_set, 'val': val_set}
        else:
            downsample_ratio = train_args['downsample_ratio']
            self.datasets = {
                'train':
                Crowd(os.path.join(datargs['data_path'],
                                   datargs["train_path"]),
                      crop_size=train_args['crop_size'],
                      downsample_ratio=downsample_ratio,
                      method='train'),
                'val':
                Crowd(os.path.join(datargs['data_path'], datargs["val_path"]),
                      crop_size=train_args['crop_size'],
                      downsample_ratio=downsample_ratio,
                      method='val')
            }
        self.dataloaders = {
            x: DataLoader(
                self.datasets[x],
                collate_fn=(train_collate
                            if x == 'train' else default_collate),
                batch_size=(train_args['batch_size'] if x == 'train' else 1),
                shuffle=(x == 'train'),
                num_workers=train_args['num_workers'] * self.device_count,
                pin_memory=(x == 'train'))
            for x in ['train', 'val']
        }
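        # vgg16dres takes a map_location argument, which suggests it loads
        # pretrained weights inside its constructor.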
        self.model = vgg16dres(map_location=self.device)
        self.model.to(self.device)
        # for p in self.model.features.parameters():
        #     p.requires_grad = True
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=train_args['lr'],
                                    weight_decay=train_args['weight_decay'],
                                    amsgrad=False)
        # for _, p in zip(range(10000), next(self.model.children()).children()):
        #     p.requires_grad = False
        #     print("freeze: ", p)
        # print(self.optimizer.param_groups[0])
        self.start_epoch = 0
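        # Unlike Example #1, this OT_Loss variant also receives the
        # SummaryWriter, presumably to log its internals.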
        self.ot_loss = OT_Loss(train_args['crop_size'], downsample_ratio,
                               train_args['norm_cood'], self.device,
                               self.logger, train_args['num_of_iter_in_ot'],
                               train_args['reg'])
        self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
        self.mse = nn.MSELoss().to(self.device)
        self.mae = nn.L1Loss().to(self.device)
        self.save_list = Save_Handle(max_num=1)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0
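        # Resume mirrors Example #1 but additionally restores the
        # best-so-far validation metrics.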
        if train_args['resume']:
            self.logger.add_text(
                'log/train',
                'loading pretrained model from ' + train_args['resume'], 0)
            suf = train_args['resume'].rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(train_args['resume'],
                                        map_location=self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
                self.best_count = checkpoint['best_count']
                self.best_mae = checkpoint['best_mae']
                self.best_mse = checkpoint['best_mse']
                print(self.best_mae, self.best_mse, self.best_count)
            elif suf == 'pth':
                self.model.load_state_dict(
                    torch.load(train_args['resume'], map_location=self.device))
        else:
            self.logger.add_text('log/train', 'random initialization', 0)
        img_cnts = {
            'val_image_count': len(self.dataloaders['val']),
            'train_image_count': len(self.dataloaders['train'])
        }
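        # Record hyper-parameters plus dataset sizes in TensorBoard's hparams
        # view, with the metrics initialized to sentinel values.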
        self.logger.add_hparams(
            {**self.train_args, **img_cnts},
            {'best_mse': np.inf, 'best_mae': np.inf, 'best_count': 0},
            run_name='hparams')
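
Both resume branches expect a .tar checkpoint with a specific set of keys. For reference, a minimal sketch of how such a checkpoint could be written; the helper name and call site are hypothetical, only the keys are taken from the code above:

import torch

def save_checkpoint(path, epoch, model, optimizer,
                    best_mae, best_mse, best_count):
    # Hypothetical helper: the keys mirror what the resume branches above
    # read back. The path should end in '.tar' so the loader takes the
    # full-restore branch rather than the weights-only '.pth' branch.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_mae': best_mae,
        'best_mse': best_mse,
        'best_count': best_count,
    }, path)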