예제 #1
0
def train_proxy(opt):
    """Run proxy-based training for ``opt.number_of_runs`` runs and log the average.

    Seeds the MXNet and NumPy RNGs, builds the feature extractor selected by
    ``opt.model``, prepares the train/validation loaders, wraps the extractor
    in a ``ProxyNet`` and calls ``train`` once per run. The per-run results
    are collected and their average is logged at the end.

    Args:
        opt: Options object. Attributes read here: ``seed``, ``gpus``,
            ``batch_size``, ``model``, ``dataset``, ``data_path``,
            ``use_crops``, ``iteration_per_epoch``, ``loss``, ``batch_k``,
            ``num_workers``, ``embed_dim``, ``lr``, ``lr_proxynca``,
            ``lr_embedding``, ``lr_inception``, ``number_of_runs`` and
            ``disable_hybridize``. ``opt.lr`` is overwritten when it is
            ``None``.

    Raises:
        RuntimeError: If ``opt.model`` is neither ``'inception-bn'`` nor
            ``'resnet50_v2'``.
    """
    logging.info(opt)

    # Set random seed
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Setup computation context
    context = get_context(opt.gpus, logging)

    run_results = []

    # Adjust batch size to each compute context
    batch_size = opt.batch_size * len(context)

    # Prepare feature extractor
    if opt.model == 'inception-bn':
        feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        data_shape = 224
        scale_image_data = False
    elif opt.model == 'resnet50_v2':
        feature_net = mx.gluon.model_zoo.vision.resnet50_v2(
            pretrained=True, ctx=context).features
        data_shape = 224
        scale_image_data = True
    else:
        raise RuntimeError('Unsupported model: %s' % opt.model)

    # Prepare datasets
    train_dataset, val_dataset = get_dataset(opt.dataset,
                                             opt.data_path,
                                             data_shape=data_shape,
                                             use_crops=opt.use_crops,
                                             use_aug=True,
                                             with_proxy=True,
                                             scale_image_data=scale_image_data)
    logging.info('Training with %d classes, validating with %d classes' %
                 (train_dataset.num_classes(), val_dataset.num_classes()))

    if opt.iteration_per_epoch > 0:
        # Fixed number of iterations per epoch: replace the plain dataset with
        # a sampling iterator and let DatasetIterator produce whole batches
        # (hence batch_size=1 on the DataLoader).
        train_dataset, _ = get_dataset_iterator(
            opt.dataset,
            opt.data_path,
            batch_k=(opt.batch_size //
                     3) if opt.loss == 'xentropy' else opt.batch_k,
            batch_size=opt.batch_size,
            data_shape=data_shape,
            use_crops=opt.use_crops,
            scale_image_data=scale_image_data,
            batchify=False)
        train_dataloader = mx.gluon.data.DataLoader(
            DatasetIterator(train_dataset,
                            opt.iteration_per_epoch,
                            'next_proxy_sample',
                            call_params={
                                'sampled_classes':
                                (opt.batch_size // opt.batch_k) if
                                (opt.batch_k is not None) else None,
                                'chose_classes_randomly':
                                True,
                            }),
            batch_size=1,
            shuffle=False,
            num_workers=opt.num_workers,
            last_batch='keep')
    else:
        train_dataloader = mx.gluon.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=opt.num_workers,
            last_batch='rollover')
    val_dataloader = mx.gluon.data.DataLoader(val_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=opt.num_workers,
                                              last_batch='keep')

    # Prepare proxy model
    net = ProxyNet(feature_net,
                   opt.embed_dim,
                   num_classes=train_dataset.num_classes())

    if opt.lr is None:
        # No global LR given: use the largest of the three per-part rates as
        # the base rate and scale each part down to its own rate via lr_mult.
        logging.info('Using variable learning rate')
        opt.lr = max([opt.lr_proxynca, opt.lr_embedding, opt.lr_inception])

        for v in net.encoder.collect_params().values():
            v.lr_mult = opt.lr_embedding / opt.lr

        # NOTE(review): for 'resnet50_v2' the backbone object is replaced in
        # the run loop below, so these multipliers presumably do not carry
        # over to the new backbone parameters -- confirm whether intended.
        for v in net.base_net.collect_params().values():
            v.lr_mult = opt.lr_inception / opt.lr

        for v in net.proxies.collect_params().values():
            v.lr_mult = opt.lr_proxynca / opt.lr
    else:
        logging.info('Using single learning rate: %f' % opt.lr)

    for run in range(1, opt.number_of_runs + 1):
        logging.info('Starting run %d/%d' % (run, opt.number_of_runs))

        # reset networks
        if opt.model == 'inception-bn':
            net.base_net.collect_params().load(feature_params,
                                               ctx=context,
                                               ignore_extra=True)

            if opt.dataset == 'CUB':
                # Freeze batch-norm parameters. The name must be matched by
                # substring; comparing the whole name against the list
                # ['batchnorm', 'bn_'] never matched any parameter.
                for v in net.base_net.collect_params().values():
                    if 'batchnorm' in v.name or 'bn_' in v.name:
                        v.grad_req = 'null'

        elif opt.model == 'resnet50_v2':
            logging.info('Lowering LR for Resnet backbone')
            net.base_net = mx.gluon.model_zoo.vision.resnet50_v2(
                pretrained=True, ctx=context).features

            # Use a smaller learning rate for pre-trained convolutional layers.
            for v in net.base_net.collect_params().values():
                if 'conv' in v.name:
                    setattr(v, 'lr_mult', 0.01)
        else:
            raise NotImplementedError('Unknown model: %s' % opt.model)

        if opt.loss == 'triplet':
            net.encoder.initialize(mx.init.Xavier(magnitude=0.2),
                                   ctx=context,
                                   force_reinit=True)
            net.proxies.initialize(mx.init.Xavier(magnitude=0.2),
                                   ctx=context,
                                   force_reinit=True)
        else:
            net.init(TruncNorm(stdev=0.001), ctx=context, init_basenet=False)
        if not opt.disable_hybridize:
            net.hybridize()

        run_result = train(net, opt, train_dataloader, val_dataloader, context,
                           run)
        run_results.append(run_result)
        logging.info('Run %d finished with %f' % (run, run_result[0][1]))

    logging.info(
        'Average validation of %d runs:\n%s' %
        (opt.number_of_runs, format_results(average_results(run_results))))
예제 #2
0
def train_discriminative(opt):
    """Train an ``EmbeddingNet`` with a class-sized output layer over several runs.

    Seeds the RNGs, builds the feature extractor selected by ``opt.model``,
    creates the train/validation loaders and an ``EmbeddingNet`` whose layer
    sizes are ``[opt.embed_dim, <number of training classes>]``. ``train`` is
    then called once per run and the averaged results are logged.

    Args:
        opt: Options object. Attributes read here: ``seed``, ``gpus``,
            ``model``, ``dataset``, ``data_path``, ``use_crops``,
            ``batch_size``, ``num_workers``, ``embed_dim``,
            ``number_of_runs`` and ``disable_hybridize``.

    Raises:
        RuntimeError: If ``opt.model`` is neither ``'inception-bn'`` nor
            ``'resnet50_v2'``.
    """
    logging.info(opt)

    # Settings.
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Setup computation context
    context = get_context(opt.gpus, logging)

    run_results = []

    # Get model
    if opt.model == 'inception-bn':
        feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        feature_net.collect_params().load(feature_params,
                                          ctx=context,
                                          ignore_extra=True)
        data_shape = 224
        scale_image_data = False
    elif opt.model == 'resnet50_v2':
        feature_net = mx.gluon.model_zoo.vision.resnet50_v2(
            pretrained=True, ctx=context).features
        data_shape = 224
        scale_image_data = True
    else:
        raise RuntimeError('Unsupported model: %s' % opt.model)

    # Get data iterators
    train_dataset, val_dataset = get_dataset(opt.dataset,
                                             opt.data_path,
                                             data_shape=data_shape,
                                             use_crops=opt.use_crops,
                                             use_aug=True,
                                             with_proxy=True,
                                             scale_image_data=scale_image_data)
    logging.info('Training with %d classes, validating with %d classes' %
                 (train_dataset.num_classes(), val_dataset.num_classes()))
    train_dataloader = mx.gluon.data.DataLoader(train_dataset,
                                                batch_size=opt.batch_size,
                                                shuffle=True,
                                                num_workers=opt.num_workers,
                                                last_batch='rollover')
    val_dataloader = mx.gluon.data.DataLoader(val_dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=opt.num_workers,
                                              last_batch='keep')

    net = EmbeddingNet(
        feature_net,
        [opt.embed_dim, train_dataset.num_classes()],
        normalize=True,
        dropout=True)

    # main run loop for multiple training runs
    for run in range(1, opt.number_of_runs + 1):
        logging.info('Starting run %d/%d' % (run, opt.number_of_runs))

        net.init(mx.init.Xavier(magnitude=0.2),
                 ctx=context,
                 init_basenet=False)

        if opt.model == 'inception-bn':
            net.base_net.collect_params().load(feature_params,
                                               ctx=context,
                                               ignore_extra=True)
        elif opt.model == 'resnet50_v2':
            net.base_net = mx.gluon.model_zoo.vision.resnet50_v2(
                pretrained=True, ctx=context).features

            # Use a smaller learning rate for pre-trained convolutional
            # layers. This has to happen after the backbone is replaced:
            # the new backbone brings new parameter objects, so a multiplier
            # assigned to the previous backbone would not apply to them.
            for v in net.base_net.collect_params().values():
                if 'conv' in v.name:
                    v.lr_mult = 0.01
        else:
            raise RuntimeError('Unsupported model: %s' % opt.model)
        if not opt.disable_hybridize:
            net.hybridize()

        run_result = train(net, opt, train_dataloader, val_dataloader, context,
                           run)
        run_results.append(run_result)
        logging.info('Run %d finished with %f' % (run, run_result[0][1]))

    logging.info(
        'Average validation of %d runs:\n%s' %
        (opt.number_of_runs, format_results(average_results(run_results))))
예제 #3
0
def train_dreml(opt):
    """Train an ensemble of ``opt.L`` embedding models one after another.

    For every ensemble member a random class mapping is applied to the
    training set, a fresh network, loss and trainer are created, the model is
    trained for ``opt.epochs`` epochs and then moved to the CPU. After each
    member finishes, ``validate`` is run over all models trained so far and
    the best validation result (by its first metric) is tracked and logged.

    Args:
        opt: Options object. Attributes read here include ``seed``, ``gpus``,
            ``batch_size``, ``model``, ``dataset``, ``data_path``,
            ``data_shape``, ``use_crops``, ``D``, ``L``, ``num_workers``,
            ``static_proxies``, ``loss``, ``label_smooth``,
            ``embedding_multiplier``, ``lr``, ``wd``, ``optimizer``,
            ``epsilon``, ``steps``, ``epochs``, ``factor``, ``kvstore`` and
            ``disable_hybridize``.

    Returns:
        None. ``best_results`` is only logged, not returned.

    Raises:
        RuntimeError: If ``opt.model`` or ``opt.loss`` is not supported.
    """
    logging.info(opt)

    # Set random seed
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Setup computation context
    context = get_context(opt.gpus, logging)
    cpu_ctx = mx.cpu()  # finished models are parked here (see reset_ctx below)

    # Adjust batch size to each compute context
    batch_size = opt.batch_size * len(context)

    if opt.model == 'inception-bn':
        scale_image_data = False
    elif opt.model in ['resnet50_v2', 'resnet18_v2']:
        scale_image_data = True
    else:
        raise RuntimeError('Unsupported model: %s' % opt.model)

    # Prepare datasets
    train_dataset, val_dataset = get_dataset(opt.dataset, opt.data_path, data_shape=opt.data_shape, use_crops=opt.use_crops,
                                             use_aug=True, with_proxy=True, scale_image_data=scale_image_data,
                                             resize_img=int(opt.data_shape * 1.1))

    # Create class mapping: one row per ensemble member, one random value in
    # [0, opt.D) per training class.
    mapping = np.random.randint(0, opt.D, (opt.L, train_dataset.num_classes()))

    # Train embedding functions one by one
    trained_models = []
    best_results = []  # R@1, NMI
    for ens in tqdm(range(opt.L), desc='Training model in ensemble'):
        # Apply this member's row of the mapping to the training set.
        train_dataset.set_class_mapping(mapping[ens], opt.D)
        train_dataloader = mx.gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                                    num_workers=opt.num_workers, last_batch='rollover')

        # A new feature extractor is built for every ensemble member.
        if opt.model == 'inception-bn':
            feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        elif opt.model == 'resnet50_v2':
            feature_net = mx.gluon.model_zoo.vision.resnet50_v2(pretrained=True, ctx=context).features
        elif opt.model == 'resnet18_v2':
            feature_net = mx.gluon.model_zoo.vision.resnet18_v2(pretrained=True, ctx=context).features
        else:
            raise RuntimeError('Unsupported model: %s' % opt.model)

        # Static proxies use a plain embedding network; otherwise the network
        # holds its own proxies (one per mapped class, opt.D in total).
        if opt.static_proxies:
            net = EmbeddingNet(feature_net, opt.D, normalize=False)
        else:
            net = ProxyNet(feature_net, opt.D, num_classes=opt.D)

        # Init loss function
        if opt.static_proxies:
            logging.info('Using static proxies')
            proxyloss = StaticProxyLoss(opt.D)
        elif opt.loss == 'nca':
            logging.info('Using NCA loss')
            proxyloss = ProxyNCALoss(opt.D, exclude_positives=True, label_smooth=opt.label_smooth,
                                     multiplier=opt.embedding_multiplier)
        elif opt.loss == 'triplet':
            logging.info('Using triplet loss')
            proxyloss = ProxyTripletLoss(opt.D)
        elif opt.loss == 'xentropy':
            logging.info('Using NCA loss without excluding positives')
            proxyloss = ProxyNCALoss(opt.D, exclude_positives=False, label_smooth=opt.label_smooth,
                                     multiplier=opt.embedding_multiplier)
        else:
            raise RuntimeError('Unknown loss function: %s' % opt.loss)

        # Init optimizer; optimizers other than the three below only get the
        # learning rate and weight decay.
        opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
        if opt.optimizer == 'sgd':
            opt_options['momentum'] = 0.9
        elif opt.optimizer == 'adam':
            opt_options['epsilon'] = opt.epsilon
        elif opt.optimizer == 'rmsprop':
            opt_options['gamma1'] = 0.9
            opt_options['epsilon'] = opt.epsilon

        # Calculate decay steps
        steps = parse_steps(opt.steps, opt.epochs, logger=logging)

        # reset networks
        if opt.model == 'inception-bn':
            net.base_net.collect_params().load(feature_params, ctx=context, ignore_extra=True)
        elif opt.model in ['resnet18_v2', 'resnet50_v2']:
            net.base_net = mx.gluon.model_zoo.vision.get_model(opt.model, pretrained=True, ctx=context).features
        else:
            raise NotImplementedError('Unknown model: %s' % opt.model)

        # Initialise the non-backbone layers (init_basenet=False leaves the
        # backbone set up above alone).
        if opt.static_proxies:
            net.init(mx.init.Xavier(magnitude=0.2), ctx=context, init_basenet=False)
        elif opt.loss == 'triplet':
            net.encoder.initialize(mx.init.Xavier(magnitude=0.2), ctx=context, force_reinit=True)
            net.proxies.initialize(mx.init.Xavier(magnitude=0.2), ctx=context, force_reinit=True)
        else:
            net.init(TruncNorm(stdev=0.001), ctx=context, init_basenet=False)
        if not opt.disable_hybridize:
            net.hybridize()

        trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer,
                                   opt_options,
                                   kvstore=opt.kvstore)

        smoothing_constant = .01  # for tracking moving losses
        moving_loss = 0

        for epoch in range(1, opt.epochs + 1):
            p_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                         desc=('[Model %d/%d] Epoch %d' % (ens + 1, opt.L, epoch)))

            # Learning rate for this epoch, derived from the decay steps.
            new_lr = get_lr(opt.lr, epoch, steps, opt.factor)
            logging.info('Setting LR to %f' % new_lr)
            trainer.set_learning_rate(new_lr)

            for i, batch in p_bar:
                # Split the batch along axis 0 across the devices; slices may
                # differ in size (even_split=False).
                data = mx.gluon.utils.split_and_load(batch[0], ctx_list=context, batch_axis=0, even_split=False)
                label = mx.gluon.utils.split_and_load(batch[1], ctx_list=context, batch_axis=0, even_split=False)
                negative_labels = mx.gluon.utils.split_and_load(batch[2], ctx_list=context, batch_axis=0,
                                                                even_split=False)

                with ag.record():
                    losses = []
                    for x, y, nl in zip(data, label, negative_labels):
                        if opt.static_proxies:
                            embs = net(x)
                            losses.append(proxyloss(embs, y))
                        else:
                            embs, positive_proxy, negative_proxies, proxies = net(x, y, nl)
                            # NCA-style losses take the proxies plus labels;
                            # the remaining (triplet) loss takes the already
                            # selected positive/negative proxies.
                            if opt.loss in ['nca', 'xentropy']:
                                losses.append(proxyloss(embs, proxies, y, nl))
                            else:
                                losses.append(proxyloss(embs, positive_proxy, negative_proxies))
                for l in losses:
                    l.backward()

                # The step size normaliser is the size of the first device's
                # slice, not the total batch size.
                trainer.step(data[0].shape[0])

                ##########################
                #  Keep a moving average of the losses
                #  (loss values are clamped at 0 from below before averaging)
                ##########################
                curr_loss = mx.nd.mean(mx.nd.maximum(mx.nd.concatenate(losses), 0)).asscalar()
                moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                               else (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss)
                p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

            logging.info('Moving loss: %.4f' % moving_loss)

        # move model to CPU (after waiting for all pending operations)
        mx.nd.waitall()
        net.collect_params().reset_ctx(cpu_ctx)
        trained_models.append(net)
        del train_dataloader

        # Run ensemble validation with every model trained so far
        logging.info('Running validation with %d models in the ensemble' % len(trained_models))
        val_dataloader = mx.gluon.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                                  num_workers=opt.num_workers, last_batch='keep')

        validation_results = validate(val_dataloader, trained_models, context, opt.static_proxies)

        for name, val_acc in validation_results:
            logging.info('Validation: %s=%f' % (name, val_acc))

        # Keep the result set whose first metric is the highest seen so far.
        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            logging.info('New best validation: R@1: %f NMI: %f' % (best_results[0][1], best_results[-1][1]))
예제 #4
0
def train_normproxy(opt):
    """Train a ``NormProxyNet`` for ``opt.number_of_runs`` runs and log the average.

    The feature extractor is chosen by ``opt.model``. Training data comes from
    ``get_dataset_iterator`` when ``opt.batch_k > 0`` and from a shuffled
    ``DataLoader`` otherwise. Each run resets the backbone, either restores
    parameters from ``normproxy_model.params`` (when ``opt.start_epoch != 1``)
    or re-initialises the remaining layers, and then calls ``train``.

    Args:
        opt: Options object. Attributes read here: ``seed``, ``gpus``,
            ``batch_size``, ``model``, ``dataset``, ``data_path``,
            ``use_crops``, ``batch_k``, ``num_workers``, ``embed_dim``,
            ``no_fc``, ``dropout``, ``static_proxies``, ``number_of_runs``,
            ``start_epoch`` and ``disable_hybridize``.

    Raises:
        RuntimeError: If ``opt.model`` is neither ``'inception-bn'`` nor
            ``'resnet50_v2'``.
    """
    logging.info(opt)

    # Seed both random number generators.
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Devices to run on; the effective batch size scales with their count.
    context = get_context(opt.gpus, logging)
    batch_size = opt.batch_size * len(context)

    run_results = []

    # Build the feature extractor (reject unknown models up front).
    if opt.model not in ('inception-bn', 'resnet50_v2'):
        raise RuntimeError('Unsupported model: %s' % opt.model)
    data_shape = 224
    if opt.model == 'inception-bn':
        feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        feature_net.collect_params().load(feature_params, ctx=context, ignore_extra=True)
        scale_image_data, feature_size = False, 1024
    else:
        feature_params = None
        feature_net = mx.gluon.model_zoo.vision.resnet50_v2(pretrained=True, ctx=context).features
        scale_image_data, feature_size = True, 2048

    # Datasets
    train_dataset, val_dataset = get_dataset(
        opt.dataset,
        opt.data_path,
        data_shape=data_shape,
        use_crops=opt.use_crops,
        use_aug=True,
        with_proxy=True,
        scale_image_data=scale_image_data)
    class_count_msg = 'Training with %d classes, validating with %d classes' % (
        train_dataset.num_classes(), val_dataset.num_classes())
    logging.info(class_count_msg)

    # Training source: the dataset iterator itself when sampling batch_k
    # images per class, otherwise a shuffled DataLoader.
    if opt.batch_k > 0:
        train_dataset, _ = get_dataset_iterator(
            opt.dataset,
            opt.data_path,
            batch_k=opt.batch_k,
            batch_size=batch_size,
            data_shape=data_shape,
            use_crops=opt.use_crops,
            scale_image_data=scale_image_data)
        train_dataloader = train_dataset
    else:
        train_dataloader = mx.gluon.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=opt.num_workers,
            last_batch='rollover')
    val_dataloader = mx.gluon.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=opt.num_workers,
        last_batch='keep')
    num_train_classes = train_dataset.num_classes()

    # Proxy model on top of the feature extractor
    net = NormProxyNet(
        feature_net,
        opt.embed_dim,
        num_classes=train_dataset.num_classes(),
        feature_size=feature_size,
        no_fc=opt.no_fc,
        dropout=opt.dropout,
        static_proxies=opt.static_proxies)

    # One full training per run
    for run in range(1, opt.number_of_runs + 1):
        logging.info('Starting run %d/%d' % (run, opt.number_of_runs))

        # Bring the backbone back to its pretrained state.
        if opt.model == 'inception-bn':
            net.base_net.collect_params().load(feature_params, ctx=context, ignore_extra=True)
        elif opt.model == 'resnet50_v2':
            net.base_net = mx.gluon.model_zoo.vision.resnet50_v2(pretrained=True, ctx=context).features

            # Pre-trained convolutional layers learn 100x slower.
            logging.info('Lowering LR for Resnet backbone by 100x')
            for param in net.base_net.collect_params().values():
                if 'conv' in param.name:
                    param.lr_mult = 0.01
        else:
            raise NotImplementedError('Unknown model: %s' % opt.model)

        # Resume from file, or initialise the non-backbone layers.
        if opt.start_epoch != 1:
            param_file = 'normproxy_model.params'
            logging.info('Loading parameters from %s' % param_file)
            net.load_parameters(param_file, ctx=context)
        elif opt.model == 'resnet50_v2':
            net.init(mx.init.Xavier(magnitude=2), ctx=context, init_basenet=False)
        else:
            net.init(TruncNorm(stdev=0.001), ctx=context, init_basenet=False)

        if not opt.disable_hybridize:
            net.hybridize()

        run_result = train(net, opt, train_dataloader, val_dataloader, num_train_classes, context, run)
        run_results.append(run_result)
        logging.info('Run %d finished with %f' % (run, run_result[0][1]))

    summary = format_results(average_results(run_results))
    logging.info('Average validation of %d runs:\n%s' % (opt.number_of_runs, summary))
예제 #5
0
def train_angular(opt):
    """Train an un-normalised ``EmbeddingNet`` on n-pairs batches over several runs.

    Builds the feature extractor selected by ``opt.model``, wraps the n-pairs
    training iterator in a ``DatasetIterator`` of length ``opt.epoch_length``,
    prepares the validation loader and calls ``train`` once per run. The
    averaged results are logged at the end.

    Args:
        opt: Options object. Attributes read here: ``seed``, ``gpus``,
            ``model``, ``embed_dim``, ``dataset``, ``data_path``,
            ``batch_size``, ``use_crops``, ``same_image_sampling``,
            ``epoch_length``, ``num_workers``, ``number_of_runs`` and
            ``disable_hybridize``.

    Raises:
        RuntimeError: If ``opt.model`` is neither ``'inception-bn'`` nor
            ``'resnet50_v2'``.
    """
    logging.info(opt)

    # Seed both random number generators.
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Devices to run on
    context = get_context(opt.gpus, logging)

    run_results = []

    # Feature extractor plus the parameter file used to reset it per run.
    if opt.model not in ('inception-bn', 'resnet50_v2'):
        raise RuntimeError('Unsupported model: %s' % opt.model)
    data_shape = 224
    if opt.model == 'inception-bn':
        feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        feature_net.collect_params().load(feature_params,
                                          ctx=context,
                                          ignore_extra=True)
        scale_image_data = False
    else:
        feature_net = mx.gluon.model_zoo.vision.resnet50_v2(
            pretrained=True, ctx=context).features
        model_root = os.path.join(mx.base.data_dir(), 'models')
        feature_params = mx.gluon.model_zoo.model_store.get_model_file(
            'resnet%d_v%d' % (50, 2), root=model_root)
        scale_image_data = True

    net = EmbeddingNet(feature_net, opt.embed_dim, normalize=False)

    if opt.model == 'resnet50_v2':
        # Pre-trained convolutional layers get a smaller learning rate.
        for param in net.base_net.collect_params().values():
            if 'conv' in param.name:
                param.lr_mult = 0.01

    # Training batches come from the n-pairs iterator; the DataLoader only
    # forwards them one at a time (batch_size=1).
    train_data, _ = get_npairs_iterators(
        opt.dataset,
        opt.data_path,
        batch_size=opt.batch_size,
        data_shape=data_shape,
        test_batch_size=len(context) * 32,
        use_crops=opt.use_crops,
        scale_image_data=scale_image_data,
        same_image_sampling=opt.same_image_sampling)
    epoch_iterator = DatasetIterator(train_data, opt.epoch_length)
    train_it_dataloader = mx.gluon.data.DataLoader(
        epoch_iterator,
        batch_size=1,
        shuffle=False,
        num_workers=opt.num_workers,
        last_batch='rollover')

    # Validation data (only the second dataset returned is used)
    _, val_dataset = get_dataset(
        opt.dataset,
        opt.data_path,
        data_shape=data_shape,
        use_crops=opt.use_crops,
        use_aug=True,
        with_proxy=False,
        scale_image_data=scale_image_data)
    val_data = mx.gluon.data.DataLoader(
        val_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.num_workers,
        last_batch='keep')

    for run in range(1, opt.number_of_runs + 1):
        logging.info('Starting run %d/%d' % (run, opt.number_of_runs))

        # Re-initialise the non-backbone layers, then reload the backbone
        # parameters from file.
        net.init(mx.init.Xavier(magnitude=0.2),
                 ctx=context,
                 init_basenet=False)
        if opt.model == 'inception-bn':
            net.base_net.collect_params().load(
                feature_params, ctx=context, ignore_extra=True)
        elif opt.model == 'resnet50_v2':
            net.base_net.load_parameters(
                feature_params,
                ctx=context,
                allow_missing=True,
                ignore_extra=True)

        if not opt.disable_hybridize:
            net.hybridize()

        run_result = train(net, opt, train_it_dataloader, val_data, context,
                           run)
        run_results.append(run_result)
        logging.info('Run %d finished with %f' % (run, run_result[0][1]))

    summary = format_results(average_results(run_results))
    logging.info('Average validation of %d runs:\n%s' %
                 (opt.number_of_runs, summary))
예제 #6
0
def train_rankedlist(opt):
    """Train a normalised ``EmbeddingNet`` on sampled batches over several runs.

    Builds the feature extractor selected by ``opt.model``, derives the
    embedding layer sizes from ``opt.bottleneck_layers`` and ``opt.embed_dim``,
    prepares a ``DatasetIterator`` based training loader and a validation
    loader, and calls ``train`` once per run. The averaged results are logged
    at the end.

    Args:
        opt: Options object. Attributes read here: ``seed``, ``gpus``,
            ``model``, ``bottleneck_layers``, ``embed_dim``, ``dataset``,
            ``data_path``, ``batch_k``, ``batch_size``, ``use_crops``,
            ``iteration_per_epoch``, ``num_workers``, ``number_of_runs`` and
            ``disable_hybridize``.

    Raises:
        RuntimeError: If ``opt.model`` is neither ``'inception-bn'`` nor
            ``'resnet50_v2'``.
    """
    logging.info(opt)

    # Settings.
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Setup computation context
    context = get_context(opt.gpus, logging)

    run_results = []

    # Get model
    if opt.model == 'inception-bn':
        feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        feature_net.collect_params().load(feature_params,
                                          ctx=context,
                                          ignore_extra=True)
        data_shape = 224
        scale_image_data = False
    elif opt.model == 'resnet50_v2':
        feature_net = mx.gluon.model_zoo.vision.resnet50_v2(
            pretrained=True, ctx=context).features
        data_shape = 224
        scale_image_data = True
        feature_params = None
    else:
        raise RuntimeError('Unsupported model: %s' % opt.model)

    # Optional bottleneck layer sizes (comma separated) precede the final
    # embedding size.
    if opt.bottleneck_layers != '':
        embedding_layers = [int(x) for x in opt.bottleneck_layers.split(',')
                            ] + [opt.embed_dim]
    else:
        embedding_layers = [opt.embed_dim]
    logging.info('Embedding layers: [%s]' %
                 ','.join([str(x) for x in embedding_layers]))
    # A single layer is passed as a plain int rather than a one-element list.
    if len(embedding_layers) == 1:
        embedding_layers = embedding_layers[0]

    net = EmbeddingNet(feature_net,
                       embedding_layers,
                       normalize=True,
                       dropout=False)
    logging.info(net)

    if opt.model != 'resnet50_v2':
        # Non-ResNet backbone: train the embedding layers with a 10x rate.
        for v in net.encoder.collect_params().values():
            setattr(v, 'lr_mult', 10.)

    # Get data iterators
    train_dataset = DatasetIterator(
        get_dataset_iterator(opt.dataset,
                             opt.data_path,
                             batch_k=opt.batch_k,
                             batch_size=opt.batch_size,
                             batchify=False,
                             data_shape=data_shape,
                             use_crops=opt.use_crops,
                             scale_image_data=scale_image_data)[0],
        opt.iteration_per_epoch, 'next')

    train_dataiterator = mx.gluon.data.DataLoader(train_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=opt.num_workers,
                                                  last_batch='keep')

    val_dataset = get_dataset(opt.dataset,
                              opt.data_path,
                              data_shape=data_shape,
                              use_crops=opt.use_crops,
                              use_aug=True,
                              scale_image_data=scale_image_data)[1]
    logging.info(
        'Training with %d classes, validating with %d classes' %
        (train_dataset.data_iterator.num_classes(), val_dataset.num_classes()))

    val_dataloader = mx.gluon.data.DataLoader(val_dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=opt.num_workers,
                                              last_batch='keep')

    # main run loop for multiple training runs
    for run in range(1, opt.number_of_runs + 1):
        logging.info('Starting run %d/%d' % (run, opt.number_of_runs))

        net.init(mx.init.Xavier(magnitude=0.2),
                 ctx=context,
                 init_basenet=False)

        if opt.model == 'inception-bn':
            net.base_net.collect_params().load(feature_params,
                                               ctx=context,
                                               ignore_extra=True)
        elif opt.model == 'resnet50_v2':
            net.base_net = mx.gluon.model_zoo.vision.resnet50_v2(
                pretrained=True, ctx=context).features

            # Use a smaller learning rate for pre-trained convolutional layers
            # and freeze batch-norm parameters. This has to happen after the
            # backbone is replaced: the new backbone brings new parameter
            # objects, so settings made on the previous one would not apply.
            for v in net.base_net.collect_params().values():
                if 'conv' in v.name:
                    v.lr_mult = 0.01
                elif 'batchnorm' in v.name or 'bn_' in v.name:
                    v.grad_req = 'null'
        else:
            raise RuntimeError('Unsupported model: %s' % opt.model)
        if not opt.disable_hybridize:
            net.hybridize()

        run_result = train(net, opt, train_dataiterator, val_dataloader,
                           context, run)
        run_results.append(run_result)
        logging.info('Run %d finished with %f' % (run, run_result[0][1]))

    logging.info(
        'Average validation of %d runs:\n%s' %
        (opt.number_of_runs, format_results(average_results(run_results))))
예제 #7
0
def train_margin(opt):
    """Train a ``MarginNet`` on sampled batches over several runs.

    Builds the feature extractor selected by ``opt.model``, prepares a
    ``DatasetIterator`` based training loader and a validation loader, and
    calls ``train`` once per run with either a learnable per-class ``beta``
    (an ``Embedding`` layer, when ``opt.lr_beta > 0``) or the constant
    ``opt.beta``. The averaged results are logged at the end.

    Args:
        opt: Options object. Attributes read here: ``seed``, ``gpus``,
            ``batch_size``, ``model``, ``embed_dim``, ``dataset``,
            ``data_path``, ``use_crops``, ``batch_k``,
            ``iteration_per_epoch``, ``num_workers``, ``number_of_runs``,
            ``disable_hybridize``, ``lr_beta`` and ``beta``.

    Raises:
        RuntimeError: If ``opt.model`` is neither ``'inception-bn'`` nor
            ``'resnet50_v2'``.
    """
    logging.info(opt)

    # Set random seed
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Setup computation context
    context = get_context(opt.gpus, logging)

    # Adjust batch size to each compute context
    batch_size = opt.batch_size * len(context)

    run_results = []

    # Get model
    if opt.model == 'inception-bn':
        feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        feature_net.collect_params().load(feature_params,
                                          ctx=context,
                                          ignore_extra=True)
        data_shape = 224
        scale_image_data = False
    elif opt.model == 'resnet50_v2':
        feature_params = None
        feature_net = mx.gluon.model_zoo.vision.resnet50_v2(
            pretrained=True, ctx=context).features
        data_shape = 224
        scale_image_data = True
    else:
        raise RuntimeError('Unsupported model: %s' % opt.model)

    net = MarginNet(feature_net, opt.embed_dim)

    # Get data iterators
    train_dataset, val_dataset = get_dataset(opt.dataset,
                                             opt.data_path,
                                             data_shape=data_shape,
                                             use_crops=opt.use_crops,
                                             use_aug=True,
                                             scale_image_data=scale_image_data)
    train_dataiter, _ = get_dataset_iterator(opt.dataset,
                                             opt.data_path,
                                             batch_k=opt.batch_k,
                                             batch_size=batch_size,
                                             data_shape=data_shape,
                                             use_crops=opt.use_crops,
                                             scale_image_data=scale_image_data,
                                             batchify=False)
    # DatasetIterator yields whole batches, hence batch_size=1 here.
    train_dataloader = mx.gluon.data.DataLoader(DatasetIterator(
        train_dataiter, opt.iteration_per_epoch, 'next'),
                                                batch_size=1,
                                                shuffle=False,
                                                num_workers=opt.num_workers,
                                                last_batch='keep')
    val_dataloader = mx.gluon.data.DataLoader(val_dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=opt.num_workers,
                                              last_batch='keep')

    logging.info('Training with %d classes, validating with %d classes' %
                 (train_dataset.num_classes(), val_dataset.num_classes()))

    # main run loop for multiple training runs
    for run in range(1, opt.number_of_runs + 1):
        logging.info('Starting run %d/%d' % (run, opt.number_of_runs))

        # Re-init embedding layers and reload pretrained layers
        if opt.model == 'inception-bn':
            net.init(mx.init.Xavier(magnitude=0.2),
                     ctx=context,
                     init_basenet=False)
            net.base_net.collect_params().load(feature_params,
                                               ctx=context,
                                               ignore_extra=True)
        elif opt.model == 'resnet50_v2':
            net.init(mx.init.Xavier(magnitude=2),
                     ctx=context,
                     init_basenet=False)
            net.base_net = mx.gluon.model_zoo.vision.resnet50_v2(
                pretrained=True, ctx=context).features

            # Use a smaller learning rate for pre-trained convolutional
            # layers. This has to happen after the backbone is replaced:
            # the new backbone brings new parameter objects, so a multiplier
            # assigned to the previous backbone would not apply to them.
            for v in net.base_net.collect_params().values():
                if 'conv' in v.name:
                    v.lr_mult = 0.01
        else:
            raise RuntimeError('Unknown model type: %s' % opt.model)

        if not opt.disable_hybridize:
            net.hybridize()

        if opt.lr_beta > 0.0:
            # One learnable beta value per training class.
            logging.info('Learning beta margin')
            beta = mx.gluon.nn.Embedding(train_dataset.num_classes(), 1)
        else:
            beta = opt.beta

        run_result = train(net, beta, opt, train_dataloader, val_dataloader,
                           batch_size, context, run)
        run_results.append(run_result)
        logging.info('Run %d finished with %f' % (run, run_result[0][1]))

    logging.info(
        'Average validation of %d runs:\n%s' %
        (opt.number_of_runs, format_results(average_results(run_results))))