Esempio n. 1
0
    def filter_requests(requests_list: List[RequestInfo],
                        bad_condition: Callable,
                        warning_message: str = None,
                        error_message: str = None,
                        logger: Logger = None):
        """Drop requests matching ``bad_condition`` and report what was dropped.

        :param requests_list: requests to filter
        :param bad_condition: predicate; a truthy result marks a request as bad
        :param warning_message: logged together with the dropped requests when
            some (but not all) requests were filtered out
        :param error_message: logged when *every* request was filtered out
        :param logger: target logger; messages are skipped when it is None
        :return: the surviving requests, or None if none survived
        """
        bad_requests = []
        filtered_requests = []

        for info in requests_list:
            if bad_condition(info):
                # Keep a readable "<METHOD> <URL>" record for the warning.
                bad_requests.append(f'{info.request.method} {info.origin_url}')
            else:
                filtered_requests.append(info)

        # Idiomatic emptiness check (was: `if not len(filtered_requests)`).
        if not filtered_requests:
            if error_message and logger:
                logger.error(error_message)
            return None

        if bad_requests and warning_message and logger:
            logger.warning(f'{warning_message}: {bad_requests}')

        return filtered_requests
Esempio n. 2
0
    def training_process(self):
        """Run the optimisation loop for ``self.cfg.iter_num`` steps."""
        self.model.train()

        # Only rank 0 owns a progress logger; all other ranks stay silent.
        if self.cfg.gpu_id == 0:
            logger = Logger(self.cfg.iter_num, self.cfg.print_interval)
        else:
            logger = None

        for step in range(self.cfg.iter_num):
            self.optimizer.zero_grad()
            imgs, priors, gts, _ = self.train_data_loader.get_batch()
            results, targets, valid_flag = self.model(imgs, priors)

            # Masked BCE: average only over the valid positions.
            raw_loss = F.binary_cross_entropy(results, targets, reduction="none")
            loss = (raw_loss * valid_flag).sum() / valid_flag.sum()
            loss.backward()
            self.optimizer.step()

            if logger is not None:
                logger.step(loss, step)

            # Periodic evaluation / checkpointing, triggered at the end of
            # each interval (equivalent to `step % n == n - 1`).
            if (step + 1) % self.cfg.eval_interval == 0:
                self.evaluate_process()
            if (step + 1) % self.cfg.ckp_interval == 0 and self.cfg.gpu_id == 0:
                torch.save(self.model.module.state_dict(),
                           "./ckp/model_{}.pth".format(step))
            self.lr_sch.step()
    def _get_logger(self, logger_name, is_resume):
        """Create the experiment Logger and write a standard header.

        The header records the experiment timestamp, every public ``Config``
        setting and the torch version.

        :param logger_name: log file name, created inside ``self.res_dir``
        :param is_resume: whether to append to an existing log file
        :return: the configured ``Logger`` instance
        """
        logger = Logger(osp.join(self.res_dir, logger_name),
                        title='log',
                        resume=is_resume)
        cur_time = time.strftime('%Y-%m-%d %H:%M:%S',
                                 time.localtime(time.time()))
        logger.write('# ===== EXPERIMENT TIME: {} ===== #'.format(cur_time),
                     add_extra_empty_line=True)
        logger.write('# ===== CONFIG SETTINGS ===== #')
        # Dump every public Config attribute.  A single startswith('_') test
        # suffices: the previous extra startswith('__') check was redundant,
        # since any dunder name also starts with '_'.
        for k, v in Config.__dict__.items():
            if not k.startswith('_'):
                logger.write('{} : {}'.format(k, v))
        logger.write('torch version:{}'.format(torch.__version__),
                     add_extra_empty_line=True)

        return logger
Esempio n. 4
0
def prepare_param_wordlist(arguments: argparse.Namespace, logger: Logger):
    """Collect a deduplicated list of parameter names from wordlist files.

    ``arguments.param_wordlist`` is a comma-separated list of file paths;
    each existing file contributes its non-blank, stripped lines.  Paths
    that do not exist are reported through ``logger`` and skipped.

    :param arguments: parsed CLI arguments with a ``param_wordlist`` attribute
    :param logger: logger used to report invalid wordlist paths
    :return: list of unique parameter names (order not guaranteed)
    """
    param_wordlist = set()

    # Raw string avoids the invalid-escape-sequence warning for '\s'.
    wordlist_paths = re.split(r'\s*,\s*', arguments.param_wordlist)

    for path in wordlist_paths:
        if not os.path.exists(path):
            logger.error(
                f'Путь "{path}" из --param-wordlist не указывает на словарь с параметрами'
            )
            continue

        with open(path) as file:
            # Iterate the file lazily; keep only non-blank stripped lines.
            param_wordlist |= {line.strip() for line in file if line.strip()}

    return list(param_wordlist)
Esempio n. 5
0
    def __init__(self, network, dataloader_train, dataloader_val, dataloader_trainval, output_dir, checkpoint_dir, log_dir):
        """Store the network, the three data loaders and the output
        locations, and create one Logger per split under ``log_dir``."""
        # Model under training.
        self.net = network

        # Per-split loggers, each writing to its own sub-directory.
        self.logger_train = Logger(os.path.join(log_dir, 'train'))
        self.logger_val = Logger(os.path.join(log_dir, 'val'))
        self.logger_trainval = Logger(os.path.join(log_dir, 'trainval'))

        # Data loaders for each split.
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.dataloader_trainval = dataloader_trainval

        # Where results and checkpoints are written.
        self.output_dir = output_dir
        self.checkpoint_dir = checkpoint_dir
Esempio n. 6
0
class TestLogger(unittest.TestCase):
    """Smoke-test every Logger ``add_*`` method over many iterations."""

    def setUp(self):
        # Fresh logger writing under logs/unittest, plus a small conv layer
        # whose state_dict exercises checkpoint logging.
        self.logger = Logger(log_dir='logs/unittest')
        self.model = nn.Conv2d(3, 32, 3, 1)

    def tearDown(self):
        # Nothing to clean up; the log directory is left for inspection.
        pass

    def test_logger(self):
        for step in range(100):
            self.logger.add_scalar('data/unittest', 3, step)
            self.logger.add_text('unittest', 'iter %d' % (step + 1), step)
            self.logger.add_array('unittest', np.random.rand(5, 5), step)
            self.logger.add_checkpoint('unittest', self.model.state_dict(), step)
Esempio n. 7
0
import __init__

import sys
from time import time

from lib.arguments import parse_args, is_args_valid, prepare_args
from lib.finders.finder import Finder
from lib.miners import Miner
from lib.reporter import Reporter
from lib.utils.logger import Logger
from lib.utils.request_helper import RequestHelper, RequestInfo, get_request_objects

if __name__ == '__main__':
    # Parse the command-line arguments
    args = parse_args()
    logger = Logger(args)

    # Validate that the supplied arguments are complete and consistent
    if not is_args_valid(args, logger):
        sys.exit(1)
    # Normalise the arguments into the form the script works with
    prepare_args(args, logger)

    logger.info('Обработка сырых запросов')

    # Timestamp used to measure how long request preparation takes
    start = time()

    # Convert the raw requests into `requests.PreparedRequest` objects
    prepared_requests, not_prepared_requests = get_request_objects(
        args.raw_requests, args, logger)
Esempio n. 8
0
class Trainer:
    """Training harness: wires up saver, logger, evaluator, data loaders,
    model, optimizer, criterion and LR scheduler from CLI args, optionally
    restores a pretrained checkpoint, then drives train/validation epochs."""

    def __init__(self, args):
        """Construct every training component from the parsed CLI ``args``.

        :param args: namespace providing save_path, num_classes, dataset,
            dataset_path, batch_size, num_workers, backbone,
            init_learning_rate, use_gpu, gpu_ids, pretrained_model_path,
            num_epochs, valid_step, ...
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)

        # Define Logger
        self.logger = Logger(args.save_path)

        # Define Evaluator
        self.evaluator = Evaluator(args.num_classes)

        # Define Best Prediction
        self.best_pred = 0.0

        # Define Last Epoch
        self.last_epoch = -1

        # Define DataLoader
        train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # NOTE(review): ToLong is not a stock torchvision transform —
        # presumably project-defined; verify.
        target_transform = transforms.Compose([
            transforms.ToLong(),
        ])
        train_dataset = Registries.dataset_registry.__getitem__(args.dataset)(
            args.dataset_path, 'train', train_transform, target_transform)
        valid_dataset = Registries.dataset_registry.__getitem__(args.dataset)(
            args.dataset_path, 'valid', valid_transform, target_transform)

        kwargs = {
            'batch_size': args.batch_size,
            'num_workers': args.num_workers,
            'pin_memory': True
        }

        # NOTE(review): shuffle=False on the *training* loader — confirm
        # this is intentional.
        self.train_loader = DataLoader(dataset=train_dataset,
                                       shuffle=False,
                                       **kwargs)
        self.valid_loader = DataLoader(dataset=valid_dataset,
                                       shuffle=False,
                                       **kwargs)

        # Define Model
        self.model = Registries.backbone_registry.__getitem__(
            args.backbone)(num_classes=10)

        # Define Optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.init_learning_rate,
                                         momentum=0.9,
                                         dampening=0.1)

        # Define Criterion
        self.criterion = FocalLoss()

        # Define  Learning Rate Scheduler
        self.scheduler = WarmUpStepLR(self.optimizer,
                                      warm_up_end_epoch=100,
                                      step_size=50,
                                      gamma=0.1)

        # Use cuda
        if torch.cuda.is_available() and args.use_gpu:
            self.device = torch.device("cuda", args.gpu_ids[0])
            if len(args.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model,
                                                   device_ids=args.gpu_ids)
        else:
            self.device = torch.device("cpu")
        self.model = self.model.to(self.device)

        # Use pretrained model
        if args.pretrained_model_path is not None:
            if not os.path.isfile(args.pretrained_model_path):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.pretrained_model_path))
            else:
                checkpoint = torch.load(args.pretrained_model_path)
                if args.use_gpu and len(args.gpu_ids) > 1:
                    self.model.module.load_state_dict(checkpoint['model'])
                else:
                    self.model.load_state_dict(checkpoint['model'])
                self.scheduler.load_state_dict(checkpoint['scheduler'])
                self.best_pred = checkpoint['best_pred']
                # Reuse the optimizer the restored scheduler wraps so the
                # two stay consistent after loading.
                self.optimizer = self.scheduler.optimizer
                self.last_epoch = checkpoint['last_epoch']
                print("=> loaded checkpoint '{}'".format(
                    args.pretrained_model_path))

    def train(self):
        """Run epochs from ``last_epoch + 1`` to ``num_epochs``, validating
        every ``valid_step`` epochs."""
        for epoch in range(self.last_epoch + 1, self.args.num_epochs):
            self._train_a_epoch(epoch)
            if epoch % self.args.valid_step == (self.args.valid_step - 1):
                self._valid_a_epoch(epoch)

    def _train_a_epoch(self, epoch):
        """Train one epoch, log iteration/epoch losses, step the LR
        scheduler and save a 'current' checkpoint."""
        print('train epoch %d' % epoch)
        total_loss = 0
        tbar = tqdm.tqdm(self.train_loader)
        self.model.train()  # change the model to train mode
        step_num = len(self.train_loader)
        for step, sample in enumerate(tbar):
            inputs, labels = sample['data'], sample[
                'label']  # get the inputs and labels from dataloader
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            if epoch == 0 and step == 0:
                # One-off: log an input image grid and the model graph.
                self.logger.show_img_grid(inputs)
                self.logger.writer.add_graph(self.model, inputs)
            self.optimizer.zero_grad(
            )  # zero the optimizer because the gradient will accumulate in PyTorch
            outputs = self.model(inputs)  # get the output(forward)
            loss = self.criterion(outputs, labels)  # compute the loss
            loss.backward()  # back propagate the loss(backward)
            total_loss += loss.item()
            self.optimizer.step()  # update the weights
            tbar.set_description('train iteration loss= %.6f' % loss.item())
            self.logger.writer.add_scalar('train iteration loss', loss,
                                          epoch * step_num + step)
        self.logger.writer.add_scalar('train epoch loss',
                                      total_loss / step_num, epoch)
        # NOTE(review): uses outputs/labels from the *last* batch only and
        # assumes the loader yielded at least one batch.
        preds = torch.argmax(outputs, dim=1)
        self.logger.add_pr_curve_tensorboard('pr curve', labels, preds)
        self.scheduler.step()  # update the learning rate
        self.saver.save_checkpoint(
            {
                'scheduler': self.scheduler.state_dict(),
                'model': self.model.state_dict(),
                'best_pred': self.best_pred,
                'last_epoch': epoch
            }, 'current_checkpoint.pth')

    def _valid_a_epoch(self, epoch):
        """Evaluate on the validation set; save a 'best' checkpoint (and the
        parameters) whenever mean IoU improves."""
        print('valid epoch %d' % epoch)
        tbar = tqdm.tqdm(self.valid_loader)
        self.model.eval()  # change the model to eval mode
        with torch.no_grad():
            for step, sample in enumerate(tbar):
                inputs, labels = sample['data'], sample[
                    'label']  # get the inputs and labels from dataloader
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)  # get the output(forward)
                predicts = torch.argmax(outputs, dim=1)
                self.evaluator.add_batch(labels.cpu().numpy(),
                                         predicts.cpu().numpy())
        new_pred = self.evaluator.Mean_Intersection_over_Union()
        print()

        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'scheduler': self.scheduler.state_dict(),
                    'model': self.model.state_dict(),
                    'best_pred': self.best_pred,
                    'last_epoch': epoch
                }, 'best_checkpoint.pth')
            self.saver.save_parameters()
Esempio n. 9
0
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument(
        "--config_file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # build model, optimizer and scheduler
    model = make_model(cfg)
    model = model.to(cfg.DEVICE)
    optimizer = build_optimizer(cfg, model, mode=cfg.SOLVER.TRAIN_MODULE)
    print('optimizer built!')
    # NOTE: add separate optimizers to train single object predictor and interaction predictor
    
    if cfg.USE_WANDB:
        logger = Logger("FOL",
                        cfg,
                        project = cfg.PROJECT,
                        viz_backend="wandb"
                        )
    else:
        logger = logging.Logger("FOL")

    dataloader_params ={
            "batch_size": cfg.SOLVER.BATCH_SIZE,
            "shuffle": True,
            "num_workers": cfg.DATALOADER.NUM_WORKERS
            }
    
    # get dataloaders
    train_dataloader = make_dataloader(cfg, 'train')
    val_dataloader = make_dataloader(cfg, 'val')
    test_dataloader = make_dataloader(cfg, 'test')
    print('Dataloader built!')
    # get train_val_test engines
    do_train, do_val, inference = build_engine(cfg)
    print('Training engine built!')
    if hasattr(logger, 'run_id'):
        run_id = logger.run_id
    else:
        run_id = 'no_wandb'

    save_checkpoint_dir = os.path.join(cfg.CKPT_DIR, run_id)
    if not os.path.exists(save_checkpoint_dir):
        os.makedirs(save_checkpoint_dir)
    
    # NOTE: hyperparameter scheduler
    model.param_scheduler = ParamScheduler()
    model.param_scheduler.create_new_scheduler(
                                        name='kld_weight',
                                        annealer=sigmoid_anneal,
                                        annealer_kws={
                                            'device': cfg.DEVICE,
                                            'start': 0,
                                            'finish': 100.0,
                                            'center_step': 400.0,
                                            'steps_lo_to_hi': 100.0, 
                                        })
    
    model.param_scheduler.create_new_scheduler(
                                        name='z_logit_clip',
                                        annealer=sigmoid_anneal,
                                        annealer_kws={
                                            'device': cfg.DEVICE,
                                            'start': 0.05,
                                            'finish': 5.0, 
                                            'center_step': 300.0,
                                            'steps_lo_to_hi': 300.0 / 5.
                                        })
    
    
    if cfg.SOLVER.scheduler == 'exp':
        # exponential schedule
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.SOLVER.GAMMA)
    elif cfg.SOLVER.scheduler == 'plateau':
        # Plateau scheduler
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5,
                                                            min_lr=1e-07, verbose=1)
    else:
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 40], gamma=0.2)
    early_stop = EarlyStopping(min_delta=0.1, patience=5, verbose=1)
                                                        
    print('Schedulers built!')

    for epoch in range(cfg.SOLVER.MAX_EPOCH):
        logger.info("Epoch:{}".format(epoch))
        do_train(cfg, epoch, model, optimizer, train_dataloader, cfg.DEVICE, logger=logger, lr_scheduler=lr_scheduler)
        val_loss = do_val(cfg, epoch, model, val_dataloader, cfg.DEVICE, logger=logger)
        if (epoch+1) % 1 == 0:
            inference(cfg, epoch, model, test_dataloader, cfg.DEVICE, logger=logger, eval_kde_nll=False)
            
        torch.save(model.state_dict(), os.path.join(save_checkpoint_dir, 'Epoch_{}.pth'.format(str(epoch).zfill(3))))

        # update LR
        if cfg.SOLVER.scheduler != 'exp':
            lr_scheduler.step(val_loss)
            cur_exp["params"][p] = params[p]
        cur_exp["results"] = {}
        cur_exp["results"]["cluster_acc"] = round(cluster_acc, 3)
        cur_exp["results"]["cluster_ari"] = round(cluster_score, 3)
        cur_exp["results"]["cluster_nmi"] = round(cluster_nmi, 3)
        cur_exp["timing"] = {}
        cur_exp["timing"]["scattering"] = t1 - t0
        cur_exp["timing"]["preprocessing"] = t2 - t1
        cur_exp["timing"]["clustering"] = t3 - t2
        cur_exp["timing"]["total"] = t3 - t0
        print_(cur_exp, verbose=verbose)
        data[f"experiment_{n_exps}"] = cur_exp
        json.dump(data, f)

    return


if __name__ == "__main__":
    # Clear the terminal before printing experiment output.
    os.system("clear")
    logger = Logger(exp_path=os.getcwd())

    # CLI processing: dataset, verbosity, seed and clustering parameters.
    dataset_name, verbose, random_seed, params = process_clustering_arguments()
    print_(params, message_type="new_exp")

    clustering_experiment(dataset_name=dataset_name,
                          params=params,
                          verbose=verbose,
                          random_seed=random_seed)

#
Esempio n. 11
0
 def __init__(self, config):
     """Keep the config, choose CUDA when available, and create the stats
     Logger in ``config["stats folder"]``."""
     self.config = config
     # Device selection: first CUDA device if available, else CPU.
     self.device = torch.device('cuda' if torch.cuda.
                                is_available() else 'cpu')
     self.logger = Logger(config["stats folder"])
Esempio n. 12
0
    'icon': os.path.join(addonPath, 'resources/images/icon.png')
}

# Create the cache / media / temp directories under the add-on's user data.
cacheDataPath = os.path.join(userDataPath, 'cache')
if not os.path.exists(cacheDataPath):
    xbmcvfs.mkdir(cacheDataPath)
cacheMediaPath = os.path.join(cacheDataPath, 'media')
if not os.path.exists(cacheMediaPath):
    xbmcvfs.mkdir(cacheMediaPath)
tmpDataPath = os.path.join(userDataPath, 'temp')
if not os.path.exists(tmpDataPath):
    xbmcvfs.mkdir(tmpDataPath)

# Properties
# Module-wide logger built from the add-on's log file path.
log = Logger(logPath)
logger = log.getLogger(__name__)

# Functions
# Shorthand aliases for the Kodi GUI/plugin API used across the add-on.
INPUT_ALPHANUM = xbmcgui.INPUT_ALPHANUM
dialog = xbmcgui.Dialog()
pDialog = xbmcgui.DialogProgress()
pDialogBG = xbmcgui.DialogProgressBG
listItem = xbmcgui.ListItem
addDirectoryItem = xbmcplugin.addDirectoryItem
endOfDirectory = xbmcplugin.endOfDirectory
execbuiltin = xbmc.executebuiltin
# Work queue and thread cap for m3u processing
# (presumably consumed by worker code elsewhere in the file — verify).
m3uQueue = Queue()
m3uMaxThreads = 20

# Constants
Esempio n. 13
0
    def __init__(self, args):
        """Construct every training component from the parsed CLI ``args``.

        :param args: namespace providing save_path, num_classes, dataset,
            dataset_path, batch_size, num_workers, backbone,
            init_learning_rate, use_gpu, gpu_ids, pretrained_model_path, ...
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)

        # Define Logger
        self.logger = Logger(args.save_path)

        # Define Evaluator
        self.evaluator = Evaluator(args.num_classes)

        # Define Best Prediction
        self.best_pred = 0.0

        # Define Last Epoch
        self.last_epoch = -1

        # Define DataLoader
        train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # NOTE(review): ToLong is not a stock torchvision transform —
        # presumably project-defined; verify.
        target_transform = transforms.Compose([
            transforms.ToLong(),
        ])
        train_dataset = Registries.dataset_registry.__getitem__(args.dataset)(
            args.dataset_path, 'train', train_transform, target_transform)
        valid_dataset = Registries.dataset_registry.__getitem__(args.dataset)(
            args.dataset_path, 'valid', valid_transform, target_transform)

        kwargs = {
            'batch_size': args.batch_size,
            'num_workers': args.num_workers,
            'pin_memory': True
        }

        # NOTE(review): shuffle=False on the *training* loader — confirm
        # this is intentional.
        self.train_loader = DataLoader(dataset=train_dataset,
                                       shuffle=False,
                                       **kwargs)
        self.valid_loader = DataLoader(dataset=valid_dataset,
                                       shuffle=False,
                                       **kwargs)

        # Define Model
        self.model = Registries.backbone_registry.__getitem__(
            args.backbone)(num_classes=10)

        # Define Optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.init_learning_rate,
                                         momentum=0.9,
                                         dampening=0.1)

        # Define Criterion
        self.criterion = FocalLoss()

        # Define  Learning Rate Scheduler
        self.scheduler = WarmUpStepLR(self.optimizer,
                                      warm_up_end_epoch=100,
                                      step_size=50,
                                      gamma=0.1)

        # Use cuda
        if torch.cuda.is_available() and args.use_gpu:
            self.device = torch.device("cuda", args.gpu_ids[0])
            if len(args.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model,
                                                   device_ids=args.gpu_ids)
        else:
            self.device = torch.device("cpu")
        self.model = self.model.to(self.device)

        # Use pretrained model
        if args.pretrained_model_path is not None:
            if not os.path.isfile(args.pretrained_model_path):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.pretrained_model_path))
            else:
                checkpoint = torch.load(args.pretrained_model_path)
                if args.use_gpu and len(args.gpu_ids) > 1:
                    self.model.module.load_state_dict(checkpoint['model'])
                else:
                    self.model.load_state_dict(checkpoint['model'])
                self.scheduler.load_state_dict(checkpoint['scheduler'])
                self.best_pred = checkpoint['best_pred']
                # Reuse the optimizer the restored scheduler wraps so the
                # two stay consistent after loading.
                self.optimizer = self.scheduler.optimizer
                self.last_epoch = checkpoint['last_epoch']
                print("=> loaded checkpoint '{}'".format(
                    args.pretrained_model_path))
Esempio n. 14
0
class SolverWrapper(object):
    """
    A wrapper class for the training process

    """

    def __init__(self, network, dataloader_train, dataloader_val, dataloader_trainval, output_dir, checkpoint_dir, log_dir):
        """Store the network, the three data loaders and the output
        locations, and create one Logger per split under ``log_dir``."""
        # Network being trained.
        self.net = network

        # One logger per data split.
        self.logger_train = Logger(os.path.join(log_dir, 'train'))
        self.logger_val = Logger(os.path.join(log_dir, 'val'))
        self.logger_trainval = Logger(os.path.join(log_dir, 'trainval'))

        # Split data loaders.
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.dataloader_trainval = dataloader_trainval

        # Output and checkpoint destinations.
        self.output_dir = output_dir
        self.checkpoint_dir = checkpoint_dir

    def snapshot(self, index):
        net = self.net

        # Store the model snapshot
        filename = 'step_{:d}'.format(index) + '.pth'
        filename = os.path.join(self.checkpoint_dir, filename)
        torch.save(net.state_dict(), filename)
        print('Write snapshot to {:s}'.format(filename))

        # also store some meta info, random state etd.
        nfilename = 'step_{:d}'.format(index) + '.pkl'
        nfilename = os.path.join(self.checkpoint_dir, nfilename)

        # current state of numpy random
        with open(nfilename, 'wb') as fid:
            pickle.dump(index, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sfile, nfile):
        """Load selected parameter groups from snapshot ``sfile`` and the
        saved iteration counter from meta file ``nfile``.

        Which groups are copied is controlled by the ``cfg.LOAD_*`` flags;
        everything else keeps the current network's values.

        :param sfile: path to the ``.pth`` model snapshot
        :param nfile: path to the ``.pkl`` meta file (pickled iteration)
        :return: the iteration count to resume from
        """
        print('Restoring mode snapshots from {:s}'.format(sfile))
        pretrained_dict = torch.load(str(sfile))
        model_dict = self.net.state_dict()

        # Backbone: the geometry / combine branches.
        if cfg.LOAD_BACKBONE:
            pretrained_dict_backbone1 = {k: v for k, v in pretrained_dict.items() if ('geometry' in k or 'combine'in k)}
            model_dict.update(pretrained_dict_backbone1)

        # Region proposal network weights.
        if cfg.LOAD_RPN:
            pretrained_dict_rpn = {k: v for k, v in pretrained_dict.items() if 'rpn' in k}
            model_dict.update(pretrained_dict_rpn)

        # Classifier head; for NYUv2 finetuning skip the cls/bbox outputs.
        if cfg.LOAD_CLASS:
            if cfg.NYUV2_FINETUNE:
                pretrained_dict_class = {k: v for k, v in pretrained_dict.items() if ('classifier' in k and 'classifier_cls' not in k and 'classifier_bbox' not in k)}
            else:
                pretrained_dict_class = {k: v for k, v in pretrained_dict.items() if 'classifier' in k}
            model_dict.update(pretrained_dict_class)

        # enet is loaded already in creat_architecture
        if cfg.USE_IMAGES:
            pretrained_dict_image = {k: v for k, v in pretrained_dict.items() if 'color' in k}
            model_dict.update(pretrained_dict_image)


        self.net.load_state_dict(model_dict)

        print('Restored')

        with open(nfile, 'rb') as fid:
            last_iter = pickle.load(fid)

        # Older snapshots stored [epoch, iter]; flatten to a single
        # global iteration count.
        if isinstance(last_iter, list):
            current_snapshot_epoch = last_iter[0]
            iter = last_iter[1]
            last_iter = len(self.dataloader_train) * current_snapshot_epoch + iter

        return last_iter


    def construct_optimizer(self, lr):
        """Build the SGD optimizer over all trainable parameters.

        Biases get ``lr * (cfg.DOUBLE_BIAS + 1)`` and a weight decay of
        ``cfg.BIAS_DECAY and cfg.WEIGHT_DECAY`` (i.e. decay is applied to
        biases only when ``cfg.BIAS_DECAY`` is truthy).

        :param lr: base learning rate
        """
        # Optimizer
        params = []
        total_weights = 0
        for key, value in dict(self.net.named_parameters()).items():
            if value.requires_grad:
                if 'bias' in key:
                    params += [{'params':[value], 'lr':lr*(cfg.DOUBLE_BIAS + 1), 'weight_decay': cfg.BIAS_DECAY and cfg.WEIGHT_DECAY}]
                else:
                    params += [{'params':[value], 'lr':lr, 'weight_decay': cfg.WEIGHT_DECAY}]
                total_weights += value.numel()
        print("total weights: {}".format(total_weights))
        self.optimizer = torch.optim.SGD(params, momentum=cfg.MOMENTUM)

        #set up lr
        # NOTE(review): this overwrites every group's LR with the base lr,
        # including the doubled bias LR set above — confirm intended.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def find_previous(self):
        """Locate existing snapshot (.pth) and meta (.pkl) files.

        Snapshots written exactly one step after an LR boundary
        (``stepsize + 1``) are excluded from both lists.

        :return: ``(count, meta_files, snapshot_files)``, mtime-ordered
        """
        sfiles = sorted(glob.glob(os.path.join(self.checkpoint_dir, 'step_*.pth')),
                        key=os.path.getmtime)

        # Snapshot names to skip: the ones taken right after an LR step.
        redfiles = [os.path.join(self.checkpoint_dir,
                                 'step_{:d}.pth'.format(stepsize + 1))
                    for stepsize in cfg.STEPSIZE]
        sfiles = [path for path in sfiles if path not in redfiles]

        nfiles = sorted(glob.glob(os.path.join(self.checkpoint_dir, 'step_*.pkl')),
                        key=os.path.getmtime)
        redfiles = [redfile.replace('.pth', '.pkl') for redfile in redfiles]
        nfiles = [path for path in nfiles if path not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self):
        """Fresh-start training state.

        :return: ``(lr, last_iter, stepsizes, np_paths, ss_paths)`` — the
            initial LR, iteration 0, the full stepsize schedule, and empty
            snapshot bookkeeping lists.
        """
        return cfg.LEARNING_RATE, 0, list(cfg.STEPSIZE), [], []

    def restore(self, sfile, nfile):
        """Restore weights/meta from a snapshot and recompute the LR state.

        Every stepsize boundary already passed scales the LR by
        ``cfg.GAMMA``; the remaining boundaries form the future schedule.

        :return: ``(lr, last_iter, stepsizes, np_paths, ss_paths)``
        """
        np_paths = [nfile]
        ss_paths = [sfile]
        last_iter = self.from_snapshot(sfile, nfile)

        lr_scale = 1
        stepsizes = []
        for boundary in cfg.STEPSIZE:
            if last_iter <= boundary:
                stepsizes.append(boundary)
            else:
                lr_scale *= cfg.GAMMA

        lr = cfg.LEARNING_RATE * lr_scale
        return lr, last_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self):
        """Delete the oldest snapshot/meta files so that at most
        ``cfg.SNAPSHOT_KEPT`` of each remain."""
        surplus = len(self.np_paths) - cfg.SNAPSHOT_KEPT
        for _ in range(surplus):
            oldest = self.np_paths[0]
            os.remove(str(oldest))
            del self.np_paths[0]

        surplus = len(self.ss_paths) - cfg.SNAPSHOT_KEPT
        for _ in range(surplus):
            oldest = self.ss_paths[0]
            os.remove(str(oldest))
            del self.ss_paths[0]

    def scale_lr(self, optimizer, lr):
        """
        Set the learning rate of every parameter group of the optimizer

        :param optimizer: optimizer whose parameter groups are updated
        :param lr: the new learning rate to apply
        :return: None
        """
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def fix_eval_parts(self):
        """Freeze parameter groups according to the ``cfg.FIX_*`` flags,
        then re-enable the classification/bbox heads for NYUv2 finetuning.
        """
        # FIX PART
        for name, var in self.net.named_parameters():
            if cfg.FIX_BACKBONE and ('geometry' in name or 'color' in name or 'combine' in name) and (not 'mask_backbone' in name):
                var.requires_grad = False
            elif cfg.FIX_RPN and 'rpn' in name:
                var.requires_grad = False
            elif cfg.FIX_CLASS and 'classifier' in name:
                var.requires_grad = False
            elif cfg.FIX_ENET and 'enet' in name:
                var.requires_grad = False

            # BUG FIX: 'classfier_cls' was a typo (cf. 'classifier_cls' in
            # from_snapshot), so the cls head was never unfrozen here.
            if cfg.NYUV2_FINETUNE and ('classifier_cls' in name or 'classifier_bbox' in name or 'classifier.4' in name):
                var.requires_grad = True

    def train_model(self, epochs):
        #1. construct the computation graph
        self.net.init_modules()

        #save net structure to data folder
        net_f = open(os.path.join(self.output_dir, 'nn.txt'), 'w')
        net_f.write(str(self.net))
        net_f.close()

        #find previous snapshot 
        lsf, nfiles, sfiles = self.find_previous()

        #2. restore weights
        if lsf == 0:
            lr, last_iter, stepsizes, self.np_paths, self.ss_paths = self.initialize()
        else:
            lr, last_iter, stepsizes, self.np_paths, self.ss_paths = self.restore(str(sfiles[-1]),
                                                                                 str(nfiles[-1]))
        #3. fix weights and eval mode
        self.fix_eval_parts()

        # construct optimizer
        self.construct_optimizer(lr)

        if len(stepsizes) != 0:
            next_stepsize = stepsizes.pop(0)
        else:
            next_stepsize = -1

        train_timer = Timer()
        current_snapshot_epoch = int(last_iter / len(self.dataloader_train))
        for epoch in range(current_snapshot_epoch, epochs):
            print("start epoch {}".format(epoch))
            with output(initial_len=9, interval=0) as content:
                for iter, blobs in enumerate(tqdm(self.dataloader_train)):
                    last_iter += 1
                    # adjust learning rate
                    if last_iter == next_stepsize:
                        lr *= cfg.GAMMA
                        self.scale_lr(self.optimizer, lr)
                        if len(stepsizes) != 0:
                            next_stepsize = stepsizes.pop(0)

                    batch_size = blobs['data'].shape[0]
                    if len(blobs['gt_box']) < batch_size: #invalid sample
                        continue
                    train_timer.tic()
                    # IMAGE PART
                    if cfg.USE_IMAGES:
                        grid_shape = blobs['data'].shape[-3:]
                        projection_helper = ProjectionHelper(cfg.INTRINSIC, cfg.PROJ_DEPTH_MIN, cfg.PROJ_DEPTH_MAX, cfg.DEPTH_SHAPE, grid_shape, cfg.VOXEL_SIZE)
                        proj_mapping = [[projection_helper.compute_projection(d.cuda(), c.cuda(), t.cuda()) for d, c, t in zip(blobs['nearest_images']['depths'][i], blobs['nearest_images']['poses'][i], blobs['nearest_images']['world2grid'][i])] for i in range(batch_size)]

                        jump_flag = False
                        for i in range(batch_size):
                            if None in proj_mapping[i]: #invalid sample
                                jump_flag = True
                                break
                        if jump_flag:
                            continue
                        
                        blobs['proj_ind_3d'] = []
                        blobs['proj_ind_2d'] = []
                        for i in range(batch_size):
                            proj_mapping0, proj_mapping1 = zip(*proj_mapping[i])
                            blobs['proj_ind_3d'].append(torch.stack(proj_mapping0))
                            blobs['proj_ind_2d'].append(torch.stack(proj_mapping1))

                        
                    self.net.forward(blobs)
                    self.optimizer.zero_grad()
                    self.net._losses["total_loss"].backward()
                    self.optimizer.step()

                    train_timer.toc()

                    # Display training information
                    if iter % (cfg.DISPLAY) == 0:
                        self.log_print(epoch*len(self.dataloader_train)+iter, lr, content, train_timer.average_time())
                    self.net.delete_intermediate_states()

                    # validate if satisfying the time criterion
                    if train_timer.total_time() / 3600 >= cfg.VAL_TIME:
                        print('------------------------VALIDATION------------------------------')
                        self.validation(last_iter, 'val')
                        print('------------------------TRAINVAL--------------------------------')
                        self.validation(last_iter, 'trainval')

                        # snapshot
                        if cfg.VAL_TIME > 0.0:
                            ss_path, np_path = self.snapshot(last_iter)
                            self.np_paths.append(np_path)
                            self.ss_paths.append(ss_path)

                            #remove old snapshots if too many
                            if len(self.np_paths) > cfg.SNAPSHOT_KEPT and cfg.SNAPSHOT_KEPT:
                                self.remove_snapshot()

                        train_timer.clean_total_time()


    def log_print(self, index, lr, content, average_time):
        """Render the current loss values into the console display buffer and
        write scalar summaries to the training logger.

        :param index: global iteration index (x-axis of the scalar summaries)
        :param lr: current learning rate, echoed on the status line
        :param content: indexable console display buffer (one string per row)
        :param average_time: average seconds per training iteration
        """
        losses = self.net._losses  # hoist the repeated attribute lookup

        content[0] = 'tqdm eats it'  # row 0 is overwritten by the tqdm bar
        if cfg.USE_RPN:
            # Per anchor level: report cls loss, box loss, and their sum.
            if cfg.NUM_ANCHORS_LEVEL1 != 0:
                cls_l1 = losses['rpn_cross_entropy_level1'].item()
                box_l1 = losses['rpn_loss_box_level1'].item()
                content[1] = '>>> rpn_loss_cls_level1: {:.9f}, rpn_loss_box_level1: {:.9f}, rpn_loss_level1: {:.9f}'. \
                        format(cls_l1, box_l1, cls_l1 + box_l1)
                #log
                self.logger_train.scalar_summary('rpn_loss_cls_level1', cls_l1, index)
                self.logger_train.scalar_summary('rpn_loss_box_level1', box_l1, index)
                self.logger_train.scalar_summary('rpn_loss_level1', cls_l1 + box_l1, index)

            if cfg.NUM_ANCHORS_LEVEL2 != 0:
                cls_l2 = losses['rpn_cross_entropy_level2'].item()
                box_l2 = losses['rpn_loss_box_level2'].item()
                content[2] = '>>> rpn_loss_cls_level2: {:.9f}, rpn_loss_box_level2: {:.9f}, rpn_loss_level2: {:.9f}'. \
                        format(cls_l2, box_l2, cls_l2 + box_l2)

                self.logger_train.scalar_summary('rpn_loss_cls_level2', cls_l2, index)
                self.logger_train.scalar_summary('rpn_loss_box_level2', box_l2, index)
                self.logger_train.scalar_summary('rpn_loss_level2', cls_l2 + box_l2, index)

            if cfg.NUM_ANCHORS_LEVEL3 != 0:
                cls_l3 = losses['rpn_cross_entropy_level3'].item()
                box_l3 = losses['rpn_loss_box_level3'].item()
                content[3] = '>>> rpn_loss_cls_level3: {:.9f}, rpn_loss_box_level3: {:.9f}, rpn_loss_level3: {:.9f}'. \
                        format(cls_l3, box_l3, cls_l3 + box_l3)

                self.logger_train.scalar_summary('rpn_loss_cls_level3', cls_l3, index)
                self.logger_train.scalar_summary('rpn_loss_box_level3', box_l3, index)
                self.logger_train.scalar_summary('rpn_loss_level3', cls_l3 + box_l3, index)

        if cfg.USE_CLASS:
            cls_loss = losses['cross_entropy'].item()
            box_loss = losses['loss_box'].item()
            content[4] = '>>> loss_cls: {:.9f}, loss_box: {:.9f}, classification_loss: {:.9f}'. \
                    format(cls_loss, box_loss, cls_loss + box_loss)

            self.logger_train.scalar_summary('classification_loss_cls', cls_loss, index)
            self.logger_train.scalar_summary('classification_loss_box', box_loss, index)
            self.logger_train.scalar_summary('classification_loss', cls_loss + box_loss, index)
        if cfg.USE_MASK:
            mask_loss = losses['loss_mask'].item()
            content[5] = '>>> mask_loss: {:.9f}'.format(mask_loss)
            self.logger_train.scalar_summary('mask_loss', mask_loss, index)

        total_loss = losses['total_loss'].item()
        content[6] = '>>> total_loss: {:.9f}, lr: {:.6f}, iteration time: {:.3f}s / iter'.format(total_loss, lr, average_time)
        # BUG FIX: previously logged the raw loss tensor here; log its Python
        # scalar, consistent with every other scalar_summary call above.
        self.logger_train.scalar_summary('total_loss', total_loss, index)

    def validation(self, index, mode):
        """Evaluate the network on chunked scenes and log mAP summaries.

        :param index: global iteration index used as the x-axis of the
                      scalar summaries
        :param mode: 'val' or 'trainval' -- selects which data loader and
                     logger are used (no other values are handled)
        """
        #####################################
        # Preparation
        #####################################
        #-------------------------------
        # metric
        #-------------------------------
        # One mAP accumulator per pipeline stage; classification and mask
        # metrics ignore the background class 0.
        mAP_RPN = Evaluate_metric(1, overlap_threshold=cfg.MAP_THRESH)
        mAP_CLASSIFICATION = Evaluate_metric(cfg.NUM_CLASSES, ignore_class=[0], overlap_threshold=cfg.MAP_THRESH)
        mAP_MASK = Evaluate_metric(cfg.NUM_CLASSES, ignore_class=[0], overlap_threshold=cfg.MAP_THRESH)
        if mode == 'val':
            data_loader = self.dataloader_val
            data_logger = self.logger_val
        elif mode == 'trainval':
            data_loader = self.dataloader_trainval
            data_logger = self.logger_trainval
        # NOTE(review): any other mode leaves data_loader/data_logger unbound
        # and raises NameError below -- confirm callers only pass these two.

        ####################################
        # Accumulate data
        ####################################
        timer = Timer()
        timer.tic()
        print('starting validation....')
        for iter, blobs in enumerate(tqdm(data_loader)):
            # if no box: skip
            if len(blobs['gt_box']) == 0:
                continue

            if cfg.USE_IMAGES:
                # Build the 3D/2D projection index pairs for every nearby
                # frame of this (single) sample; compute_projection returning
                # None means the frame has no valid projection into the grid.
                grid_shape = blobs['data'].shape[-3:]
                projection_helper = ProjectionHelper(cfg.INTRINSIC, cfg.PROJ_DEPTH_MIN, cfg.PROJ_DEPTH_MAX, cfg.DEPTH_SHAPE, grid_shape, cfg.VOXEL_SIZE)
                proj_mapping = [projection_helper.compute_projection(d.cuda(), c.cuda(), t.cuda()) for d, c, t in zip(blobs['nearest_images']['depths'][0], blobs['nearest_images']['poses'][0], blobs['nearest_images']['world2grid'][0])]

                if None in proj_mapping: #invalid sample
                    continue
                
                # Split the (ind3d, ind2d) pairs into two stacked tensors the
                # network consumes alongside the voxel data.
                blobs['proj_ind_3d'] = []
                blobs['proj_ind_2d'] = []
                proj_mapping0, proj_mapping1 = zip(*proj_mapping)
                blobs['proj_ind_3d'].append(torch.stack(proj_mapping0))
                blobs['proj_ind_2d'].append(torch.stack(proj_mapping1))

            # Inference pass; empty list = no frames to drop.
            self.net.forward(blobs, 'TEST', [])
            #--------------------------------------
            # RPN: loss, metric 
            #--------------------------------------
            if cfg.USE_RPN:
                # (n, 6)
                gt_box = blobs['gt_box'][0].numpy()[:, 0:6]
                # Single-class RPN: every gt box gets label 0.
                gt_box_label = np.zeros(gt_box.shape[0])

                # NOTE(review): bare except -- presumably guards against the
                # threshold/nonzero lookup failing; the fallback keeps only
                # the single top-ranked roi. Narrow the exception type once
                # the actual failure mode is confirmed.
                try:
                    pred_box_num = (self.net._predictions['roi_scores'][0][:, 0] > cfg.ROI_THRESH).nonzero().size(0)
                    pred_box = self.net._predictions['rois'][0].cpu().numpy()[:pred_box_num]
                    pred_box_label = np.zeros(pred_box_num) 
                    pred_box_score = self.net._predictions['roi_scores'][0].cpu().numpy()[:pred_box_num, 0]
                except:
                    pred_box = self.net._predictions['rois'][0].cpu().numpy()[:1]
                    pred_box_label = np.zeros(1)
                    pred_box_score = self.net._predictions['roi_scores'][0].cpu().numpy()[:1, 0]

                #evaluation metric 
                mAP_RPN.evaluate(pred_box,
                                 pred_box_label,
                                 pred_box_score,
                                 gt_box,
                                 gt_box_label)

            #--------------------------------------
            # Classification: loss, metric 
            #--------------------------------------
            if cfg.USE_CLASS:
                # groundtruth: box coords are columns 0:6, class id column 6
                gt_box = blobs['gt_box'][0].numpy()[:, 0:6]
                gt_class = blobs['gt_box'][0][:, 6].numpy()

                # predictions
                pred_class = self.net._predictions['cls_pred'].data.cpu().numpy()

                # only predictions['rois'] is list and is Tensor / others are no list and Variable
                rois = self.net._predictions['rois'][0].cpu()
                box_reg_pre = self.net._predictions["bbox_pred"].data.cpu().numpy()
                box_reg = np.zeros((box_reg_pre.shape[0], 6))
                pred_conf_pre = self.net._predictions['cls_prob'].data.cpu().numpy()
                pred_conf = np.zeros((pred_conf_pre.shape[0]))


                # For each roi, pick the 6-dim regression slice and the
                # confidence belonging to its predicted class.
                for pred_ind in range(pred_class.shape[0]):
                    box_reg[pred_ind, :] = box_reg_pre[pred_ind, pred_class[pred_ind]*6:(pred_class[pred_ind]+1)*6]
                    pred_conf[pred_ind] = pred_conf_pre[pred_ind, pred_class[pred_ind]]

                # Decode regression offsets into boxes, clipped to the scene.
                pred_box = bbox_transform_inv(rois, torch.from_numpy(box_reg).float())
                pred_box = clip_boxes(pred_box, self.net._scene_info[:3]).numpy()

                # pickup: keep only detections above the confidence threshold
                sort_index = []
                for conf_index in range(pred_conf.shape[0]):
                    if pred_conf[conf_index] > cfg.CLASS_THRESH:
                        sort_index.append(True)
                    else:
                        sort_index.append(False)

                # eliminate bad box (degenerate after rounding to voxel coords)
                for idx, box in enumerate(pred_box):
                    if round(box[0]) >= round(box[3]) or round(box[1]) >= round(box[4]) or round(box[2]) >= round(box[5]):
                        sort_index[idx] = False

                if len(pred_box[sort_index]) == 0:
                    print('no pred box')

                # Dump the first cfg.VAL_NUM samples for offline inspection;
                # directory name is derived from the scene id.
                if iter < cfg.VAL_NUM:
                    os.makedirs('{}/{}'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), exist_ok=True)
                    np.save('{}/{}/pred_class'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_class)
                    np.save('{}/{}/pred_conf'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_conf)
                    np.save('{}/{}/pred_box'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_box)
                    # binarized scene volume (<= 1 -> occupied); presumably a
                    # TSDF surface extraction -- confirm against data loader
                    np.save('{}/{}/scene'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), np.where(blobs['data'][0,0].numpy() <= 1, 1, 0))
                    np.save('{}/{}/gt_class'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), gt_class)
                    np.save('{}/{}/gt_box'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), gt_box)

                mAP_CLASSIFICATION.evaluate(
                        pred_box[sort_index],
                        pred_class[sort_index],
                        pred_conf[sort_index],
                        gt_box,
                        gt_class)

            #--------------------------------------
            # MASK: loss, metric 
            #--------------------------------------
            if cfg.USE_MASK:
                # gt data
                gt_box = blobs['gt_box'][0].numpy()[:, 0:6]
                gt_class = blobs['gt_box'][0][:, 6].numpy()
                gt_mask = blobs['gt_mask'][0]

                pred_class = self.net._predictions['cls_pred'].data.cpu().numpy()
                pred_conf = np.zeros((pred_class.shape[0]))
                for pred_ind in range(pred_class.shape[0]):
                    pred_conf[pred_ind] = self.net._predictions['cls_prob'].data.cpu().numpy()[pred_ind, pred_class.data[pred_ind]]

                # pickup (vectorized boolean mask this time)
                sort_index = pred_conf > cfg.CLASS_THRESH

                # eliminate bad box
                # NOTE(review): pred_box is only assigned in the USE_CLASS
                # branch above -- this block appears to assume cfg.USE_CLASS
                # is enabled whenever cfg.USE_MASK is; confirm.
                for idx, box in enumerate(pred_box):
                    if round(box[0]) >= round(box[3]) or round(box[1]) >= round(box[4]) or round(box[2]) >= round(box[5]):
                        sort_index[idx] = False

                # Binarize each kept detection's mask logits at MASK_THRESH.
                # mask_ind walks the compacted mask_pred list, which contains
                # entries only for kept detections.
                pred_mask = []
                mask_ind = 0
                for ind, cls in enumerate(pred_class):
                    if sort_index[ind]:
                        mask = self.net._predictions['mask_pred'][0][mask_ind][0][cls].data.cpu().numpy()
                        mask = np.where(mask >=cfg.MASK_THRESH, 1, 0).astype(np.float32)
                        pred_mask.append(mask)
                        mask_ind += 1

                if iter < cfg.VAL_NUM: 
                    pickle.dump(pred_mask, open('{}/{}/pred_mask'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), 'wb'))
                    pickle.dump(sort_index, open('{}/{}/pred_mask_index'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), 'wb'))
                    pickle.dump(gt_mask, open('{}/{}/gt_mask'.format(cfg.VAL_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), 'wb'))

                mAP_MASK.evaluate_mask(
                        pred_box[sort_index],
                        pred_class[sort_index],
                        pred_conf[sort_index],
                        pred_mask,
                        gt_box,
                        gt_class, 
                        gt_mask, 
                        self.net._scene_info)

            # Drop the per-sample cached state before the next iteration.
            self.net.delete_intermediate_states()
        timer.toc()
        print('It took {:.3f}s for Validation on chunks'.format(timer.total_time()))

        ###################################
        # Summary
        ###################################
        if cfg.USE_RPN:
            mAP_RPN.finalize()
            print('AP of RPN: {}'.format(mAP_RPN.mAP()))
            data_logger.scalar_summary('AP_ROI', mAP_RPN.mAP(), index)

        if cfg.USE_CLASS:
            mAP_CLASSIFICATION.finalize()
            print('mAP of CLASSIFICATION: {}'.format(mAP_CLASSIFICATION.mAP()))
            for class_ind in range(cfg.NUM_CLASSES):
                if class_ind not in mAP_CLASSIFICATION.ignore_class:
                    print('class {}: {}'.format(class_ind, mAP_CLASSIFICATION.AP(class_ind)))
            data_logger.scalar_summary('mAP_CLASSIFICATION', mAP_CLASSIFICATION.mAP(), index)

        if cfg.USE_MASK:
            mAP_MASK.finalize()
            print('mAP of mask: {}'.format(mAP_MASK.mAP()))
            for class_ind in range(cfg.NUM_CLASSES):
                if class_ind not in mAP_MASK.ignore_class:
                    print('class {}: {}'.format(class_ind, mAP_MASK.AP(class_ind)))
            data_logger.scalar_summary('mAP_MASK', mAP_MASK.mAP(), index)

    @staticmethod
    def benchmark(net, data_loader, data_logger):
        """Run detection (and optionally mask) inference on whole scans and
        cache the raw outputs under cfg.TEST_SAVE_DIR.

        Scans whose pred_box.npy already exists on disk skip the detection
        pipeline and reuse the cached arrays. No metrics are computed here;
        data_logger is currently unused.

        :param net: network exposing forward(), _predictions, mask_backbone
        :param data_loader: iterable yielding whole-scan blobs
        :param data_logger: unused
        """
        #####################################
        # Preparation
        #####################################
        os.makedirs(cfg.TEST_SAVE_DIR, exist_ok=True)

        ####################################
        # Accumulate data
        ####################################
        timer = Timer()
        timer.tic()
        print('starting test on whole scan....')
        for iter, blobs in enumerate(tqdm(data_loader)):
            # Reuse cached detection outputs when present (keyed by scene id).
            FLAG_EXIST_CLASS = False
            if os.path.isfile('{}/{}/pred_box.npy'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12])):
                pred_class = np.load('{}/{}/pred_class.npy'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]))
                pred_conf = np.load('{}/{}/pred_conf.npy'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]))
                pred_box = np.load('{}/{}/pred_box.npy'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]))
                FLAG_EXIST_CLASS = True

            #--------------------------------------
            # Classification: loss, metric 
            #--------------------------------------
            if not FLAG_EXIST_CLASS:
                # color proj
                killing_inds = None
                if cfg.USE_IMAGES:
                    grid_shape = blobs['data'].shape[-3:]
                    projection_helper = ProjectionHelper(cfg.INTRINSIC, cfg.PROJ_DEPTH_MIN, cfg.PROJ_DEPTH_MAX, cfg.DEPTH_SHAPE, grid_shape, cfg.VOXEL_SIZE)
                    # For oversized scenes/image sets the projection stays on
                    # CPU tensors (no .cuda()) -- presumably to avoid running
                    # out of GPU memory; confirm.
                    if grid_shape[0]*grid_shape[1]*grid_shape[2] > cfg.MAX_VOLUME or blobs['nearest_images']['depths'][0].shape[0] > cfg.MAX_IMAGE:
                        proj_mapping = [projection_helper.compute_projection(d, c, t) for d, c, t in zip(blobs['nearest_images']['depths'][0], blobs['nearest_images']['poses'][0], blobs['nearest_images']['world2grid'][0])]
                    else:
                        proj_mapping = [projection_helper.compute_projection(d.cuda(), c.cuda(), t.cuda()) for d, c, t in zip(blobs['nearest_images']['depths'][0], blobs['nearest_images']['poses'][0], blobs['nearest_images']['world2grid'][0])]
                        
                    # Unlike validation(), frames with no valid projection do
                    # not skip the scan: their indices are collected in
                    # killing_inds and handed to net.forward to be dropped.
                    killing_inds = []
                    real_proj_mapping = []
                    if None in proj_mapping: #invalid sample
                        for killing_ind, killing_item in enumerate(proj_mapping):
                            if killing_item == None:
                                killing_inds.append(killing_ind)
                            else:
                                real_proj_mapping.append(killing_item)
                        print('{}: (invalid sample: no valid projection)'.format(blobs['id']))
                    else:
                        real_proj_mapping = proj_mapping
                    blobs['proj_ind_3d'] = []
                    blobs['proj_ind_2d'] = []
                    proj_mapping0, proj_mapping1 = zip(*real_proj_mapping)
                    blobs['proj_ind_3d'].append(torch.stack(proj_mapping0))
                    blobs['proj_ind_2d'].append(torch.stack(proj_mapping1))

                net.forward(blobs, 'TEST', killing_inds)

                # test with detection pipeline
                pred_class = net._predictions['cls_pred'].data.cpu().numpy()
                rois = net._predictions['rois'][0].cpu()
                box_reg_pre = net._predictions["bbox_pred"].data.cpu().numpy()
                box_reg = np.zeros((box_reg_pre.shape[0], 6))
                pred_conf_pre = net._predictions['cls_prob'].data.cpu().numpy()
                pred_conf = np.zeros((pred_conf_pre.shape[0]))

                # Per roi: select the regression slice and confidence of the
                # predicted class from the per-class outputs.
                for pred_ind in range(pred_class.shape[0]):
                    box_reg[pred_ind, :] = box_reg_pre[pred_ind, pred_class[pred_ind]*6:(pred_class[pred_ind]+1)*6]
                    pred_conf[pred_ind] = pred_conf_pre[pred_ind, pred_class[pred_ind]]

                # Decode offsets into boxes and clip to the scene extents.
                pred_box = bbox_transform_inv(rois, torch.from_numpy(box_reg).float())
                pred_box = clip_boxes(pred_box, net._scene_info[:3]).numpy()

                # pickup: keep detections above the confidence threshold
                sort_index = []
                for conf_index in range(pred_conf.shape[0]):
                    if pred_conf[conf_index] > cfg.CLASS_THRESH:
                        sort_index.append(True)
                    else:
                        sort_index.append(False)

                # eliminate bad box (degenerate after rounding)
                for idx, box in enumerate(pred_box):
                    if round(box[0]) >= round(box[3]) or round(box[1]) >= round(box[4]) or round(box[2]) >= round(box[5]):
                        sort_index[idx] = False

                # Cache detection outputs so a re-run can skip this scan.
                os.makedirs('{}/{}'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), exist_ok=True)
                np.save('{}/{}/pred_class'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_class)
                np.save('{}/{}/pred_conf'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_conf)
                np.save('{}/{}/pred_box'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_box)
                np.save('{}/{}/scene'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), np.where(blobs['data'][0,0].numpy() <= 1, 1, 0))

            if cfg.USE_MASK:
                # pickup: recomputed here because pred_conf/pred_box may have
                # been loaded from the on-disk cache above.
                sort_index = []
                for conf_index in range(pred_conf.shape[0]):
                    if pred_conf[conf_index] > cfg.CLASS_THRESH:
                        sort_index.append(True)
                    else:
                        sort_index.append(False)

                # eliminate bad box
                for idx, box in enumerate(pred_box):
                    if round(box[0]) >= round(box[3]) or round(box[1]) >= round(box[4]) or round(box[2]) >= round(box[5]):
                        sort_index[idx] = False

                # test with mask pipeline: crop the scene volume to each kept
                # (rounded) box and run the mask backbone on the crop.
                net.mask_backbone.eval()
                net.mask_backbone.cuda()
                mask_pred_batch = []
                for net_i in range(1):  # batch size is fixed to 1 here
                    mask_pred = []
                    for pred_box_ind, pred_box_item in enumerate(pred_box):
                        if sort_index[pred_box_ind]:
                            mask_pred.append(net.mask_backbone(Variable(blobs['data'].cuda())[net_i:net_i+1, :, 
                                                                            int(round(pred_box_item[0])):int(round(pred_box_item[3])),
                                                                            int(round(pred_box_item[1])):int(round(pred_box_item[4])), 
                                                                            int(round(pred_box_item[2])):int(round(pred_box_item[5]))
                                                                            ], [] if cfg.USE_IMAGES else None))

                    mask_pred_batch.append(mask_pred)
                net._predictions['mask_pred'] = mask_pred_batch

                # save test result: binarize each kept detection's mask at
                # MASK_THRESH; mask_ind walks the compacted mask_pred list.
                pred_mask = []
                mask_ind = 0
                for ind, cls in enumerate(pred_class):
                    if sort_index[ind]:
                        mask = net._predictions['mask_pred'][0][mask_ind][0][cls].data.cpu().numpy()
                        mask = np.where(mask >=cfg.MASK_THRESH, 1, 0).astype(np.float32)
                        pred_mask.append(mask)
                        mask_ind += 1

                pickle.dump(pred_mask, open('{}/{}/pred_mask'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), 'wb'))
                pickle.dump(sort_index, open('{}/{}/pred_mask_index'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), 'wb'))

        timer.toc()
        print('It took {:.3f}s for test on whole scenes'.format(timer.total_time()))

    @staticmethod
    def test(net, data_loader, data_logger):
        """Evaluate the network on whole scans: save raw outputs to
        cfg.TEST_SAVE_DIR and print classification/mask mAP summaries.

        Unlike benchmark(), this always runs the detection pipeline and also
        evaluates against ground truth.

        :param net: network exposing forward(), _predictions, mask_backbone
        :param data_loader: iterable yielding whole-scan blobs with gt boxes
        :param data_logger: unused in the visible body (summaries are printed)
        """
        #####################################
        # Preparation
        #####################################
        os.makedirs(cfg.TEST_SAVE_DIR, exist_ok=True)
        mAP_CLASSIFICATION = Evaluate_metric(cfg.NUM_CLASSES, ignore_class=[0], overlap_threshold=cfg.MAP_THRESH)
        mAP_MASK = Evaluate_metric(cfg.NUM_CLASSES, ignore_class=[0], overlap_threshold=cfg.MAP_THRESH)

        ####################################
        # Accumulate data
        ####################################
        # NOTE(review): pred_all/gt_all are never written or read in this
        # method -- apparently leftovers; confirm before removing.
        pred_all = {}
        gt_all = {}

        timer = Timer()
        timer.tic()
        print('starting test on whole scan....')
        for iter, blobs in enumerate(tqdm(data_loader)):

            # Skip scans whose blobs carry no usable ground truth; the bare
            # except presumably catches indexing an empty gt_box -- confirm.
            try:
                gt_box = blobs['gt_box'][0].numpy()[:, 0:6]
                gt_class = blobs['gt_box'][0][:, 6].numpy()
            except:
                continue

            # color proj
            killing_inds = None
            if cfg.USE_IMAGES:
                grid_shape = blobs['data'].shape[-3:]
                projection_helper = ProjectionHelper(cfg.INTRINSIC, cfg.PROJ_DEPTH_MIN, cfg.PROJ_DEPTH_MAX, cfg.DEPTH_SHAPE, grid_shape, cfg.VOXEL_SIZE)
                # Oversized scenes/image sets keep projection on CPU tensors
                # (no .cuda()) -- presumably a GPU-memory safeguard; confirm.
                if grid_shape[0]*grid_shape[1]*grid_shape[2] > cfg.MAX_VOLUME or blobs['nearest_images']['depths'][0].shape[0] > cfg.MAX_IMAGE:
                    proj_mapping = [projection_helper.compute_projection(d, c, t) for d, c, t in zip(blobs['nearest_images']['depths'][0], blobs['nearest_images']['poses'][0], blobs['nearest_images']['world2grid'][0])]
                else:
                    proj_mapping = [projection_helper.compute_projection(d.cuda(), c.cuda(), t.cuda()) for d, c, t in zip(blobs['nearest_images']['depths'][0], blobs['nearest_images']['poses'][0], blobs['nearest_images']['world2grid'][0])]
                    
                # Frames with no valid projection are dropped via killing_inds
                # (passed to net.forward) instead of skipping the whole scan.
                killing_inds = []
                real_proj_mapping = []
                if None in proj_mapping: #invalid sample
                    for killing_ind, killing_item in enumerate(proj_mapping):
                        if killing_item == None:
                            killing_inds.append(killing_ind)
                        else:
                            real_proj_mapping.append(killing_item)
                    print('{}: (invalid sample: no valid projection)'.format(blobs['id']))
                else:
                    real_proj_mapping = proj_mapping
                blobs['proj_ind_3d'] = []
                blobs['proj_ind_2d'] = []
                proj_mapping0, proj_mapping1 = zip(*real_proj_mapping)
                blobs['proj_ind_3d'].append(torch.stack(proj_mapping0))
                blobs['proj_ind_2d'].append(torch.stack(proj_mapping1))

            net.forward(blobs, 'TEST', killing_inds)

            # test with detection pipeline
            pred_class = net._predictions['cls_pred'].data.cpu().numpy()
            rois = net._predictions['rois'][0].cpu()
            box_reg_pre = net._predictions["bbox_pred"].data.cpu().numpy()
            box_reg = np.zeros((box_reg_pre.shape[0], 6))
            pred_conf_pre = net._predictions['cls_prob'].data.cpu().numpy()
            pred_conf = np.zeros((pred_conf_pre.shape[0]))

            # Per roi: pick the regression slice and confidence of the
            # predicted class from the per-class outputs.
            for pred_ind in range(pred_class.shape[0]):
                box_reg[pred_ind, :] = box_reg_pre[pred_ind, pred_class[pred_ind]*6:(pred_class[pred_ind]+1)*6]
                pred_conf[pred_ind] = pred_conf_pre[pred_ind, pred_class[pred_ind]]

            # Decode offsets into boxes and clip to scene extents.
            pred_box = bbox_transform_inv(rois, torch.from_numpy(box_reg).float())
            pred_box = clip_boxes(pred_box, net._scene_info[:3]).numpy()

            # Save raw outputs (and ground truth) per scene id.
            os.makedirs('{}/{}'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), exist_ok=True)
            np.save('{}/{}/pred_class'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_class)
            np.save('{}/{}/pred_conf'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_conf)
            np.save('{}/{}/pred_box'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), pred_box)
            np.save('{}/{}/scene'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), np.where(blobs['data'][0,0].numpy() <= 1, 1, 0))
            np.save('{}/{}/gt_class'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), gt_class)
            np.save('{}/{}/gt_box'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), gt_box)

            # pickup: keep detections above the confidence threshold
            sort_index = []
            for conf_index in range(pred_conf.shape[0]):
                if pred_conf[conf_index] > cfg.CLASS_THRESH:
                    sort_index.append(True)
                else:
                    sort_index.append(False)

            # eliminate bad box (degenerate after rounding)
            for idx, box in enumerate(pred_box):
                if round(box[0]) >= round(box[3]) or round(box[1]) >= round(box[4]) or round(box[2]) >= round(box[5]):
                    sort_index[idx] = False

            mAP_CLASSIFICATION.evaluate(
                    pred_box[sort_index],
                    pred_class[sort_index],
                    pred_conf[sort_index],
                    gt_box,
                    gt_class)

            if cfg.USE_MASK:
                gt_mask = blobs['gt_mask'][0]
                # pickup (recomputed; same filtering as above)
                sort_index = []
                for conf_index in range(pred_conf.shape[0]):
                    if pred_conf[conf_index] > cfg.CLASS_THRESH:
                        sort_index.append(True)
                    else:
                        sort_index.append(False)

                # eliminate bad box
                for idx, box in enumerate(pred_box):
                    if round(box[0]) >= round(box[3]) or round(box[1]) >= round(box[4]) or round(box[2]) >= round(box[5]):
                        sort_index[idx] = False

                # test with mask pipeline: crop the scene volume to each kept
                # (rounded) box and run the mask backbone on the crop.
                net.mask_backbone.eval()
                net.mask_backbone.cuda()
                mask_pred_batch = []
                for net_i in range(1):  # batch size is fixed to 1 here
                    mask_pred = []
                    for pred_box_ind, pred_box_item in enumerate(pred_box):
                        if sort_index[pred_box_ind]:
                            mask_pred.append(net.mask_backbone(Variable(blobs['data'].cuda())[net_i:net_i+1, :, 
                                                                            int(round(pred_box_item[0])):int(round(pred_box_item[3])),
                                                                            int(round(pred_box_item[1])):int(round(pred_box_item[4])), 
                                                                            int(round(pred_box_item[2])):int(round(pred_box_item[5]))
                                                                            ], [] if cfg.USE_IMAGES else None))

                    mask_pred_batch.append(mask_pred)
                net._predictions['mask_pred'] = mask_pred_batch

                # save test result: binarize each kept detection's mask at
                # MASK_THRESH; mask_ind walks the compacted mask_pred list.
                pred_mask = []
                mask_ind = 0
                for ind, cls in enumerate(pred_class):
                    if sort_index[ind]:
                        mask = net._predictions['mask_pred'][0][mask_ind][0][cls].data.cpu().numpy()
                        mask = np.where(mask >=cfg.MASK_THRESH, 1, 0).astype(np.float32)
                        pred_mask.append(mask)
                        mask_ind += 1

                pickle.dump(pred_mask, open('{}/{}/pred_mask'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), 'wb'))
                pickle.dump(sort_index, open('{}/{}/pred_mask_index'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), 'wb'))
                pickle.dump(gt_mask, open('{}/{}/gt_mask'.format(cfg.TEST_SAVE_DIR, blobs['id'][0].split('/')[-1][:12]), 'wb'))

                mAP_MASK.evaluate_mask(
                        pred_box[sort_index],
                        pred_class[sort_index],
                        pred_conf[sort_index],
                        pred_mask,
                        gt_box,
                        gt_class, 
                        gt_mask, 
                        net._scene_info)

        timer.toc()
        print('It took {:.3f}s for test on whole scenes'.format(timer.total_time()))

        ###################################
        # Summary
        ###################################
        if cfg.USE_CLASS:
            mAP_CLASSIFICATION.finalize()
            print('mAP of CLASSIFICATION: {}'.format(mAP_CLASSIFICATION.mAP()))
            for class_ind in range(cfg.NUM_CLASSES):
                if class_ind not in mAP_CLASSIFICATION.ignore_class:
                    print('class {}: {}'.format(class_ind, mAP_CLASSIFICATION.AP(class_ind)))

        if cfg.USE_MASK:
            mAP_MASK.finalize()
            print('mAP of mask: {}'.format(mAP_MASK.mAP()))
            for class_ind in range(cfg.NUM_CLASSES):
                if class_ind not in mAP_MASK.ignore_class:
                    print('class {}: {}'.format(class_ind, mAP_MASK.AP(class_ind)))
# Esempio n. 15
# 0
def is_args_valid(arguments: argparse.Namespace, logger: Logger) -> bool:
    """Validate the parsed command-line arguments.

    :param arguments: parsed command-line arguments
    :param logger: logger used to report validation errors
    :return: True if every argument is valid, False otherwise
    """
    if not arguments.url and not arguments.raw_requests:
        logger.error('Требуется указать один из аргументов -u или -r')
        return False

    if arguments.url:
        addr = urlparse(arguments.url)

        # -u must be either an existing file or a full URL (scheme + host)
        if not (os.path.isfile(arguments.url) or (addr.scheme and addr.netloc)):
            logger.error('Некорректный формат аргумента -u')
            return False

    if arguments.raw_requests and not os.path.exists(arguments.raw_requests):
        logger.error('Указанного пути -r не существует')
        return False

    # Each wordlist option is a comma-separated list of dictionary files;
    # the matching scan type requires at least one wordlist to be given.
    wordlist_checks = [
        (arguments.param_wordlist, '--param-wordlists',
         arguments.find_all or arguments.find_params),
        (arguments.header_wordlist, '--header-wordlists',
         arguments.find_all or arguments.find_headers),
        (arguments.cookie_wordlist, '--cookie-wordlists',
         arguments.find_all or arguments.find_cookies),
    ]

    for wordlist, option, required in wordlist_checks:
        if wordlist:
            bad_paths = [
                path for path in re.split(r'\s*,\s*', wordlist)
                if not os.path.isfile(path)
            ]

            if bad_paths:
                logger.error(
                    f'Следующие пути {option} не указывают на словари: ' +
                    '"' + '", "'.join(bad_paths) + '"')
                return False
        elif required:
            logger.error(
                f'Требуется указать хотя бы один словарь {option} для поиска параметров'
            )
            return False

    if not (arguments.find_headers or arguments.find_params
            or arguments.find_cookies or arguments.find_all):
        logger.error(
            'Не указан тип сканирования --find-headers / --find-params / --find-cookies / --find-all'
        )
        return False

    if arguments.retry <= 0:
        logger.error(
            'Общее число попыток --retry выполнить запрос должно быть больше 0'
        )
        return False

    if arguments.timeout <= 0:
        logger.error('Время ожидания ответа --timeout должно быть больше 0')
        # BUG FIX: previously validation fell through and returned True
        # even when the timeout was invalid.
        return False

    return True
Esempio n. 16
0
    def get_optimal_bucket(self, info: RequestInfo, min_chunk: int, add_random: Callable,
                           additional_size: Callable, logger: Logger) -> Union[int, None]:
        """Search for the optimal parameter-chunk size, maximizing the
        (chunk length) / (response time) ratio.

        Each iteration probes three candidate sizes (left, cur, right) by
        sending padded copies of the baseline request concurrently, then
        narrows the search window depending on which probes still get a
        "normal" response from the server.

        :param info: baseline request/response pair used as the reference
        :param min_chunk: minimal acceptable chunk size
        :param add_random: callable that pads a request copy to a given length
        :param additional_size: callable returning a size correction for the request
        :param logger: logger for debug output on unclassified responses
        :return: the chosen chunk size plus the correction, or None if no probe succeeded
        """
        left, cur, right = 1024, 2048, 4096
        left_border = 0
        right_border = math.inf

        counter = 5

        optimal_size = None
        optimal_rate = 0

        # Hard cap on the number of search iterations
        while counter:
            counter -= 1

            # If the left bound has collapsed to zero
            if left == 0:
                break

            # If the range can no longer be subdivided, stop the loop
            if right - cur < 2 or cur - left < 2:
                break

            # Prepare the probe requests
            _requests = [info.copy_request() for _ in range(3)]
            for request, length in zip(_requests, [left, cur, right]):
                add_random(request, length)

            # Send them concurrently
            jobs = [gevent.spawn(self.do_request, request) for request in _requests]
            gevent.joinall(jobs)
            responses = [job.value for job in jobs]

            # Classify the results: True = looks like the original response,
            # False = a size limit was exceeded, None = no response at all
            results = []
            # results = [response.status_code == info.response.status_code if response is not None else response
            #            for response in responses]

            for response in responses:
                if not response:
                    results.append(None)
                # Status code matches the original response
                elif response.status_code == info.response.status_code:
                    results.append(True)
                # Payload Too Large / URI Too Long / Request Header Fields Too Large
                elif response.status_code in {413, 414, 431}:
                    results.append(False)
                # Status in [500, 599] while the original status is not in that range
                elif 500 <= response.status_code < 600 and not 500 <= info.response.status_code < 600:
                    results.append(False)
                # Status in [400, 499] while the original status is not in that range
                elif 400 <= response.status_code < 500 and not 400 <= info.response.status_code < 500:
                    results.append(False)
                else:
                    logger.debug(f'Необработанный случай: act_status_code={response.status_code}, orig_status_cod={info.response.status_code}')
                    results.append(True)

            # If no request got a response from the server, shift the window left
            if not any(results):
                right_border = left

                right = right_border
                cur = right >> 1
                left = cur >> 1

                continue

            # Otherwise pick the best candidate among the responses
            rates = []

            for response, size, result in zip([response for response in responses], [left, cur, right], results):
                # Only score probes that stayed within the server's limits
                elapsed = response.elapsed.total_seconds() if (response is not None and result == True) else math.inf
                rate = round(size / elapsed, 1)
                rates.append(rate)

                if rate > optimal_rate and result:
                    optimal_rate = rate
                    optimal_size = size

            # See in which direction the trend improves
            max_rate = max(rates)

            # If no probe exceeded the limits, move in the direction of the trend
            if all(results):
                # If the trend improves on the left
                if rates[0] == max_rate:
                    right_border = right

                    # Then shift the window left
                    right = left - 1
                    cur = right >> 1
                    left = cur >> 1

                    # If the left pointer fell below the left border
                    if left < left_border:
                        # Recompute the pointers within the borders
                        left, cur, right = self.shift_bounds(left_border, right_border)

                # If the trend improves on the right
                elif rates[2] == max_rate:
                    left_border = left

                    # Then shift the window right
                    left = right + 1
                    cur = left << 1
                    right = cur << 1

                    # If the right pointer went past the right border
                    if right > right_border:
                        # Recompute the pointers within the borders
                        left, cur, right = self.shift_bounds(left_border, right_border)

                # Otherwise examine the neighborhood of the center
                else:
                    left_border = left if left > left_border else left_border
                    right_border = right if right < right_border else right_border

                    left = (left + cur) // 2
                    right = (cur + right) // 2
            # If the results are [True, False|None, False|None]
            elif results[0] == True and all([not r for r in results[1:]]):
                right_border = cur if cur < right_border else right_border
                # Then shift the window left
                right = left - 1
                cur = right >> 1
                left = cur >> 1
            # If the results are [True, True, False|None]
            elif results[2] in {None, False} and all([r for r in results[:2]]):
                right_border = right if right < right_border else right_border
                # Then compare the trend on the left and in the middle

                # If the trend improves on the left
                if rates[0] == max_rate:
                    # Then shift the window left
                    right = left - 1  # Move the examined right bound to 1 below the previously examined left bound
                    cur = right >> 1
                    left = cur >> 1

                    # If the left pointer fell below the left border
                    if left < left_border:
                        # Recompute the pointers within the borders
                        left, cur, right = self.shift_bounds(left_border, right_border)
                # Otherwise dig around cur
                else:
                    right = round((cur + right) / 2)
                    left = (left + cur) // 2
            else:
                # Shift the window left
                right = left - 1  # Move the examined right bound to 1 below the previously examined left bound
                cur = right >> 1
                left = cur >> 1

        # If the resulting optimal size is below the required minimum, return the minimum instead
        if optimal_size is not None:
            if optimal_size < min_chunk < right_border:
                return min_chunk + additional_size(info)

            return optimal_size + additional_size(info)

        return optimal_size
Esempio n. 17
0
 def setUp(self):
     """Build a fresh logger and a small convolution layer for each test case."""
     # Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1)
     self.model = nn.Conv2d(3, 32, 3, 1)
     self.logger = Logger(log_dir='logs/unittest')
Esempio n. 18
0
 def __init__(self, experiment):
     """Bind the experiment, pick a compute device, and set up logging."""
     self.experiment = experiment
     # Prefer the GPU when one is available, otherwise fall back to CPU
     device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
     self.device = torch.device(device_name)
     self.logger = Logger(self.experiment.stats_folder)