Exemplo n.º 1
0
 def __init__(self, seq, buffer, dataset_path, model_path, set_type='train',
              max_res=800, branch_arch='alexnet', ctx_mode='max'):
     """Load a dataset sequence and a siamese net restored from a checkpoint.

     Args:
         seq: (int) The number of the sequence according to the get_sequence
             function, which mirrors the indexing of the ImageNetVID class.
         buffer: (queue.Queue) The data buffer between the producerThread and
             the consumer application (the display). The elements stored in
             this buffer are defined by the BufferElement namedtuple.
         dataset_path: (string) The path to the root of the ImageNet dataset.
         model_path: (string) The path to the models .pth.tar file containing
             the model's weights.
         set_type: (string) The subset of the ImageNet VID dataset, can be
             'train' or 'val'.
         max_res: (int) The maximum resolution in pixels. If any dimension
             of the image exceeds this value, the final image published by
             the producer is resized (keeping the aspect ratio). Used to
             balance the load between the consumer (main) thread and the
             producer.
         branch_arch: (string) The architecture of the branch of the siamese
             net. Might be: 'alexnet', 'vgg11_5c' or 'vgg16_8c'.
         ctx_mode: (string) The strategy used to define the context region
             around the target, using the bounding box dimensions. The 'max'
             mode uses the biggest dimension, while the 'mean' mode uses the
             mean of the dimensions.

     Raises:
         ValueError: If branch_arch is not one of the supported
             architectures listed above.
     """
     super(ProducerThread, self).__init__(daemon=True)
     # Factories for the supported branch architectures. The lambdas keep
     # the attribute lookups on `mdl` lazy, so only the requested net is
     # ever touched.
     embedding_factories = {
         'alexnet': lambda: mdl.BaselineEmbeddingNet(),
         'vgg11_5c': lambda: mdl.VGG11EmbeddingNet_5c(),
         'vgg16_8c': lambda: mdl.VGG16EmbeddingNet_8c(),
     }
     # Fail fast, before loading the sequence, instead of hitting an
     # AttributeError on a never-assigned self.net further down.
     if branch_arch not in embedding_factories:
         raise ValueError(
             "Unknown branch_arch '{}'. Expected one of: {}".format(
                 branch_arch, ', '.join(sorted(embedding_factories))))
     self.frames, self.bboxes_norm, self.valid_frames, self.vid_dims = (
         get_sequence(seq, dataset_path, set_type=set_type))
     self.idx = 0
     self.seq_size = len(self.frames)
     self.buffer = buffer
     # NOTE(review): max_res is documented above but is neither stored nor
     # used in this constructor; confirm whether the resizing logic expects
     # it to be kept (e.g. as self.max_res).
     # TODO put the model info inside the checkpoint file.
     self.net = mdl.SiameseNet(embedding_factories[branch_arch](), stride=4)
     # The identity map_location deserializes every tensor on the CPU; the
     # net is moved to `device` only after the weights are loaded.
     checkpoint = torch.load(model_path,
                             map_location=lambda storage, loc: storage)
     self.net.load_state_dict(checkpoint['state_dict'])
     # Tuple of (H, W), the dimensions to which the image will be resized.
     self.resize_dims = None
     self.net = self.net.to(device)
     self.net.eval()
     self.ref, self.ref_emb = self.make_ref(ctx_mode=ctx_mode)
Exemplo n.º 2
0
def main(args):
    """Export the baseline siamese net as a training checkpoint.

    Builds a ``SiameseNet`` with a ``BaselineEmbeddingNet`` branch and fills
    it through ``load_baseline`` using the file at ``args.net_file_path``.
    It then writes a checkpoint to ``args.dst_path`` through
    ``export_to_checkpoint``. The checkpoint holds epoch 0, the model's
    state dict, and the state dict of a freshly created Adam optimizer
    (lr=1e-3).
    """
    siamese = load_baseline(args.net_file_path,
                            mdl.SiameseNet(mdl.BaselineEmbeddingNet()))
    # The optimizer is only created so that its (initial) state can be
    # stored alongside the weights.
    adam = optim.Adam(siamese.parameters(), lr=1e-3)
    state = {
        'epoch': 0,
        'state_dict': siamese.state_dict(),
        'optim_dict': adam.state_dict(),
    }
    export_to_checkpoint(state, checkpoint=args.dst_path)
Exemplo n.º 3
0
def _build_siamese_model(model_name, upscale):
    """Build the SiameseNet whose embedding branch is named ``model_name``.

    Args:
        model_name: (string) Name of the embedding net, as given by the
            experiment parameters (``params.model``).
        upscale: Forwarded unchanged to ``SiameseNet`` as ``upscale``.

    Returns:
        A ``mdl.SiameseNet`` built with ``corr_map_size=33`` and ``stride=4``.

    Raises:
        ValueError: If ``model_name`` is not a supported embedding net.
    """
    # Lambdas keep the lookups on `mdl` lazy, so only the requested
    # architecture is ever touched.
    embedding_factories = {
        'BaselineEmbeddingNet': lambda: mdl.BaselineEmbeddingNet(),
        'VGG11EmbeddingNet_5c': lambda: mdl.VGG11EmbeddingNet_5c(),
        'VGG16EmbeddingNet_8c': lambda: mdl.VGG16EmbeddingNet_8c(),
    }
    if model_name not in embedding_factories:
        raise ValueError("Unknown model '{}'. Supported models: {}".format(
            model_name, ', '.join(sorted(embedding_factories))))
    return mdl.SiameseNet(embedding_factories[model_name](),
                          upscale=upscale,
                          corr_map_size=33,
                          stride=4)


def _freeze_parameters(model, freeze_indices):
    """Set ``requires_grad = False`` on the parameters at the given positions.

    Args:
        model: Module whose ``named_parameters()`` are enumerated.
        freeze_indices: Iterable of positions in the ``named_parameters()``
            enumeration order to freeze.
    """
    # Build the lookup set once instead of scanning a list per parameter.
    freeze_indices = set(freeze_indices)
    for i, (name, parameter) in enumerate(model.named_parameters()):
        if i in freeze_indices:
            logging.info("Freezing parameter {}".format(name))
            parameter.requires_grad = False


def _clamp_epoch_size(split, epoch_size, dataset_len):
    """Return ``epoch_size`` capped at ``dataset_len``, logging when capped.

    Args:
        split: (string) 'train' or 'eval'; only used in the log message.
        epoch_size: (int) The epoch size requested by the user.
        dataset_len: (int) The length of the corresponding data loader.

    Returns:
        ``dataset_len`` if ``epoch_size`` exceeds it, else ``epoch_size``.
    """
    if epoch_size > dataset_len:
        logging.info('The user set {0}_epoch_size ({1}) is bigger than the '
                     'size of the {0} set ({2}). \n Setting '
                     '{0}_epoch_size to the {0} set size.'.format(
                         split, epoch_size, dataset_len))
        return dataset_len
    return epoch_size


def main(args):
    """Train or evaluate a SiameseNet experiment on ImageNet VID.

    The experiment directory is ``training/experiments/<args.exp_name>``
    next to this file and must contain a ``parameters.json``.

    - In 'train' mode, the training set is built and ``train_and_evaluate``
      is run.
    - In 'eval' mode, ``<args.restore_file>.pth.tar`` is loaded from the
      experiment directory and evaluated on the validation set. The metrics
      are saved to ``metrics_test_<args.restore_file>.json``.

    Errors raised while training or evaluating are logged with their
    traceback and are not re-raised. A KeyboardInterrupt is logged and
    re-raised.

    Args:
        args: Parsed command-line arguments. Attributes read here:
            data_dir, exp_name, timer, num_workers, mode, imutils_flag and,
            in eval mode, restore_file. The whole object is also forwarded
            to ``train_and_evaluate`` / ``evaluate``.

    Raises:
        FileNotFoundError: If the experiment has no ``parameters.json``.
        ValueError: If ``params.model`` is not a supported embedding net.
    """
    root_dir = dirname(abspath(__file__))
    imagenet_dir = args.data_dir
    exp_dir = join(root_dir, 'training', 'experiments', args.exp_name)

    # Load the parameters from the experiment's json file. An explicit raise
    # (instead of assert) keeps the check active under `python -O`.
    json_path = join(exp_dir, 'parameters.json')
    if not isfile(json_path):
        raise FileNotFoundError(
            "No json configuration file found at {}".format(json_path))
    params = train_utils.Params(json_path)
    # Add the command-line timer and worker options to the parameters
    params.update_with_dict({'timer': args.timer})
    params.update_with_dict({'num_workers': args.num_workers})

    train_utils.set_logger(join(exp_dir, '{}.log'.format(args.mode)))
    logging.info("----Starting train script in mode: {}----".format(args.mode))

    setup_timer = Timer(convert=True)
    setup_timer.reset()
    logging.info("Loading datasets...")

    model = _build_siamese_model(params.model, params.upscale)
    _freeze_parameters(model, params.parameter_freeze)
    model = model.to(device)
    # Set the tensorboard summary maker
    summ_maker = SummaryMaker(join(exp_dir, 'tensorboard'), params,
                              model.upscale_factor)

    img_read_fcn = imutils.get_decode_jpeg_fcn(flag=args.imutils_flag)
    img_resize_fcn = imutils.get_resize_fcn(flag=args.imutils_flag)
    # Settings shared by the train and validation datasets, so both splits
    # are always built with the same labelling and preprocessing.
    dataset_kwargs = dict(label_fcn=create_BCELogit_loss_label,
                          pos_thr=params.pos_thr,
                          neg_thr=params.neg_thr,
                          upscale_factor=model.upscale_factor,
                          cxt_margin=params.context_margin,
                          reference_size=params.reference_sz,
                          search_size=params.search_sz,
                          img_read_fcn=img_read_fcn,
                          resize_fcn=img_resize_fcn,
                          max_frame_sep=params.max_frame_sep)

    logging.info("Validation dataset...")
    metadata_val_file = join(exp_dir, "metadata.val")
    val_set = ImageNetVID_val(imagenet_dir,
                              metadata_file=metadata_val_file,
                              save_metadata=metadata_val_file,
                              **dataset_kwargs)
    val_loader = DataLoader(val_set,
                            batch_size=params.batch_size,
                            shuffle=False,
                            num_workers=params.num_workers,
                            pin_memory=True)
    params.eval_epoch_size = _clamp_epoch_size(
        'eval', params.eval_epoch_size, len(val_loader))

    # fetch loss function and metrics
    loss_fn = losses.BCELogit_Loss
    # NOTE: `metrics` is the module-level met.METRICS dict itself (no copy is
    # made), so the keyword-argument assignment below mutates it in place.
    metrics = met.METRICS
    metrics['center_error']['kwargs']['upscale_factor'] = model.upscale_factor

    try:
        if args.mode == 'train':
            logging.info("Training dataset...")
            metadata_train_file = join(exp_dir, "metadata.train")
            train_set = ImageNetVID(imagenet_dir,
                                    metadata_file=metadata_train_file,
                                    save_metadata=metadata_train_file,
                                    **dataset_kwargs)
            train_loader = DataLoader(train_set,
                                      batch_size=params.batch_size,
                                      shuffle=True,
                                      num_workers=params.num_workers,
                                      pin_memory=True)
            # The user-provided value is overridden (and the change logged)
            # when it exceeds len(train_loader).
            params.train_epoch_size = _clamp_epoch_size(
                'train', params.train_epoch_size, len(train_loader))

            logging.info("Done")
            logging.info("Setup time: {}".format(setup_timer.elapsed))
            # Only the parameters left trainable by _freeze_parameters are
            # handed to the optimizer.
            parameters = filter(lambda p: p.requires_grad, model.parameters())
            optimizer = optimz.OPTIMIZERS[params.optim](parameters,
                                                        **params.optim_kwargs)
            # The scheduler updates the learning rate using an exponential
            # decay. If you don't want lr decay set it to 1.
            logging.info("Using Exponential Learning Rate Decay of {}".format(
                params.lr_decay))
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, params.lr_decay)

            logging.info("Epoch sizes: {} in train and {} in eval".format(
                params.train_epoch_size, params.eval_epoch_size))

            logging.info("Starting training for {} epoch(s)".format(
                params.num_epochs))
            with Timer(convert=True) as t:
                train_and_evaluate(model,
                                   train_loader,
                                   val_loader,
                                   optimizer,
                                   scheduler,
                                   loss_fn,
                                   metrics,
                                   params,
                                   exp_dir,
                                   args,
                                   summ_maker=summ_maker)
            if params.timer:
                logging.info(
                    "[profiling] Total time to train {} epochs, with {}"
                    " elements on training dataset and {} "
                    "on val dataset: {}".format(params.num_epochs,
                                                len(train_loader),
                                                len(val_loader), t.elapsed))

        elif args.mode == 'eval':
            logging.info("Done")
            with Timer(convert=True) as total:
                logging.info("Starting evaluation")
                # TODO write a decent Exception
                if args.restore_file is None:
                    raise IncompleteArgument("In eval mode you have to specify"
                                             " a model checkpoint to be loaded"
                                             " and evaluated."
                                             " E.g: --restore_file best")
                checkpoint_path = join(exp_dir, args.restore_file + '.pth.tar')
                train_utils.load_checkpoint(checkpoint_path, model)
                # Evaluate
                summ_maker.epoch = 0
                test_metrics = evaluate(model,
                                        loss_fn,
                                        val_loader,
                                        metrics,
                                        params,
                                        args,
                                        summ_maker=summ_maker)
                save_path = join(
                    exp_dir, "metrics_test_{}.json".format(args.restore_file))
                train_utils.save_dict_to_json(test_metrics, save_path)
            if params.timer:
                logging.info("[profiling] Total evaluation time: {}".format(
                    total.elapsed))

        else:
            # Previously an unknown mode silently did nothing and was logged
            # as a normal exit; report it as an error instead.
            raise ValueError(
                "Unknown mode '{}'. Expected 'train' or 'eval'.".format(
                    args.mode))

    except KeyboardInterrupt:
        logging.info("=== User interrupted execution ===")
        raise
    except Exception:
        # Top-level boundary: log the full traceback and finish without
        # re-raising.
        logging.exception("Fatal error in main loop")
        logging.info("=== Execution Terminated with error ===")
    else:
        logging.info("=== Execution exited normally ===")