Example #1
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd, 'clip_gradient': 10.}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7
    trainer = mx.gluon.Trainer(net.collect_params(),
                               opt.optimizer,
                               opt_options,
                               kvstore=opt.kvstore)

    L = DiscriminativeLoss(train_dataloader._dataset.num_classes(),
                           len(train_dataloader._dataset))
    L.initialize(ctx=context)
    if not opt.disable_hybridize:
        L.hybridize()

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader),
                     total=len(train_dataloader),
                     desc=('[Run %d/%d] Epoch %d' %
                           (run_id, opt.number_of_runs, epoch)))
        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))

        for i, batch in p_bar:
            data = mx.gluon.utils.split_and_load(batch[0],
                                                 ctx_list=context,
                                                 batch_axis=0,
                                                 even_split=False)
            label = mx.gluon.utils.split_and_load(batch[1],
                                                  ctx_list=context,
                                                  batch_axis=0,
                                                  even_split=False)
            negative_labels = mx.gluon.utils.split_and_load(batch[2],
                                                            ctx_list=context,
                                                            batch_axis=0,
                                                            even_split=False)

            with ag.record():
                losses = []
                for x, y, nl in zip(data, label, negative_labels):
                    embs = net(x)
                    losses.append(L(embs, y, nl))
            for l in losses:
                l.backward()

            trainer.step(len(losses))

            #  Keep a moving average of the losses
            curr_loss = mx.nd.mean(mx.nd.concatenate(losses)).asscalar()
            moving_loss = (
                curr_loss if ((i == 0) and (epoch == 1))  # starting value
                else (1 - smoothing_constant) * moving_loss +
                smoothing_constant * curr_loss)  # add current
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' %
                         (epoch, name, val_acc))

        if (len(best_results)
                == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f' %
                         (best_results[0][1], best_results[-1][1]))

    return best_results
Example #2
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""

    if not opt.skip_pretrain_validation:
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7
    trainer = mx.gluon.Trainer(net.collect_params(),
                               opt.optimizer,
                               opt_options,
                               kvstore=opt.kvstore)

    L = PrototypeLoss(opt.nc, opt.ns, opt.nq)

    # Episode size: Nc classes, each contributing Ns support and Nq query samples
    data_size = opt.nc * (opt.ns + opt.nq)

    best_results = []  # R@1, NMI
    for epoch in range(1, opt.epochs + 1):

        prev_loss, cumulative_loss = 0.0, 0.0

        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))
        logging.info('Epoch %d learning rate=%f', epoch, trainer.learning_rate)

        p_bar = tqdm(train_dataloader,
                     desc=('[Run %d/%d] Epoch %d' %
                           (run_id, opt.number_of_runs, epoch)))
        for batch in p_bar:
            supports_batch, queries_batch, labels_batch = [x[0] for x in batch]
            # supports_batch: <Nc x Ns x I>
            # queries_batch: <Nc x Nq x I>
            # labels_batch: <Nc x 1>

            supports_batch = mx.nd.reshape(supports_batch, (-1, 0, 0, 0),
                                           reverse=True)  # <(Nc * Ns) x I>
            queries_batch = mx.nd.reshape(queries_batch, (-1, 0, 0, 0),
                                          reverse=True)

            queries = mx.gluon.utils.split_and_load(queries_batch,
                                                    ctx_list=context,
                                                    batch_axis=0)
            supports = mx.gluon.utils.split_and_load(supports_batch,
                                                     ctx_list=context,
                                                     batch_axis=0)

            support_embs = []
            queries_embs = []
            with ag.record():
                for s in supports:
                    s_emb = net(s)
                    support_embs.append(s_emb)
                supports = mx.nd.concat(*support_embs, dim=0)  # <Nc*Ns x E>

                for q in queries:
                    q_emb = net(q)
                    queries_embs.append(q_emb)
                queries = mx.nd.concat(*queries_embs, dim=0)  # <Nc*Nq x E>

                loss = L(supports, queries)

            loss.backward()
            cumulative_loss += mx.nd.mean(loss).asscalar()
            trainer.step(data_size)

            p_bar.set_postfix({'loss': cumulative_loss - prev_loss})
            prev_loss = cumulative_loss

        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' %
                         (epoch, name, val_acc))

        if (len(best_results)
                == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f' %
                         (best_results[0][1], best_results[-1][1]))

    return best_results
Example #3
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""

    if not opt.skip_pretrain_validation:
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    # Calculate decay steps
    steps = parse_steps(opt.steps, opt.epochs, logger=logging)

    # Init optimizer
    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd, 'clip_gradient': 10.}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    elif opt.optimizer in ['adam', 'radam']:
        opt_options['epsilon'] = opt.epsilon
    elif opt.optimizer == 'rmsprop':
        opt_options['gamma1'] = 0.9
        opt_options['epsilon'] = opt.epsilon

    trainer = mx.gluon.Trainer(net.collect_params(),
                               opt.optimizer,
                               opt_options,
                               kvstore=opt.kvstore)

    # Init loss function
    if opt.loss == 'nca':
        logging.info('Using NCA loss')
        proxyloss = ProxyNCALoss(train_dataloader._dataset.num_classes(),
                                 exclude_positives=True,
                                 label_smooth=opt.label_smooth,
                                 multiplier=opt.embedding_multiplier,
                                 temperature=opt.temperature)
    elif opt.loss == 'triplet':
        logging.info('Using triplet loss')
        proxyloss = ProxyTripletLoss(train_dataloader._dataset.num_classes())
    elif opt.loss == 'xentropy':
        logging.info('Using NCA loss without excluding positives')
        proxyloss = ProxyNCALoss(train_dataloader._dataset.num_classes(),
                                 exclude_positives=False,
                                 label_smooth=opt.label_smooth,
                                 multiplier=opt.embedding_multiplier,
                                 temperature=opt.temperature)
    else:
        raise RuntimeError('Unknown loss function: %s' % opt.loss)

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader),
                     total=len(train_dataloader),
                     desc=('[Run %d/%d] Epoch %d' %
                           (run_id, opt.number_of_runs, epoch)))

        new_lr = get_lr(opt.lr, epoch, steps, opt.factor)
        logging.info('Setting LR to %f' % new_lr)
        trainer.set_learning_rate(new_lr)
        if opt.optimizer == 'rmsprop':
            # exponential decay of gamma
            if epoch != 1:
                trainer._optimizer.gamma1 *= .94
                logging.info('Setting rmsprop gamma to %f' %
                             trainer._optimizer.gamma1)

        for (i, batch) in p_bar:
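            # When iteration_per_epoch > 0, each batch field is wrapped in an extra leading dim; take element 0 to unwrap it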
            if opt.iteration_per_epoch > 0:
                for b in range(len(batch)):
                    batch[b] = batch[b][0]
            data = mx.gluon.utils.split_and_load(batch[0],
                                                 ctx_list=context,
                                                 batch_axis=0,
                                                 even_split=False)
            label = mx.gluon.utils.split_and_load(batch[1],
                                                  ctx_list=context,
                                                  batch_axis=0,
                                                  even_split=False)
            negative_labels = mx.gluon.utils.split_and_load(batch[2],
                                                            ctx_list=context,
                                                            batch_axis=0,
                                                            even_split=False)

            with ag.record():
                losses = []
                for x, y, nl in zip(data, label, negative_labels):
                    embs, positive_proxy, negative_proxies, proxies = net(
                        x, y, nl)
                    if opt.loss in ['nca', 'xentropy']:
                        losses.append(proxyloss(embs, proxies, y, nl))
                    else:
                        losses.append(
                            proxyloss(embs, positive_proxy, negative_proxies))
            for l in losses:
                l.backward()

            trainer.step(data[0].shape[0])

            #  Keep a moving average of the losses
            curr_loss = mx.nd.mean(mx.nd.maximum(mx.nd.concatenate(losses),
                                                 0)).asscalar()
            moving_loss = (
                curr_loss if ((i == 0) and (epoch == 1))  # starting value
                else (1 - smoothing_constant) * moving_loss +
                smoothing_constant * curr_loss)
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' %
                         (epoch, name, val_acc))

        if (len(best_results)
                == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f' %
                         (best_results[0][1], best_results[-1][1]))

    return best_results
Example #4
def train(net, opt, train_data, val_data, num_train_classes, context, run_id):
    """Training function"""

    if not opt.skip_pretrain_validation:
        validation_results = validate(net, val_data, context, binarize=opt.binarize, nmi=opt.nmi, similarity=opt.similarity)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    # Calculate decay steps
    steps = parse_steps(opt.steps, opt.epochs, logger=logging)

    # Init optimizer
    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd, 'clip_gradient': 10.}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    elif opt.optimizer == 'adam':
        opt_options['epsilon'] = opt.epsilon
    elif opt.optimizer == 'rmsprop':
        opt_options['gamma1'] = 0.9
        opt_options['epsilon'] = opt.epsilon

    # We train only embedding and proxies initially
    params2train = net.encoder.collect_params()
    if not opt.static_proxies:
        params2train.update(net.proxies.collect_params())

    trainer = mx.gluon.Trainer(params2train, opt.optimizer, opt_options, kvstore=opt.kvstore)

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    batch_size = opt.batch_size * len(context)

    proxyloss = ProxyXentropyLoss(num_train_classes, label_smooth=opt.label_smooth, temperature=opt.temperature)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        if epoch == 2:
            # switch training to all parameters
            logging.info('Switching to train all parameters')
            trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer, opt_options, kvstore=opt.kvstore)
        if opt.batch_k > 0:
            iterations_per_epoch = int(ceil(train_data.num_training_images() / batch_size))
            p_bar = tqdm(range(iterations_per_epoch), desc='[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch),
                         total=iterations_per_epoch)
        else:
            p_bar = tqdm(enumerate(train_data), total=len(train_data),
                         desc=('[Run %d/%d] Epoch %d' % (run_id, opt.number_of_runs, epoch)))

        new_lr = get_lr(opt.lr, epoch, steps, opt.factor)
        logging.info('Setting LR to %f' % new_lr)
        trainer.set_learning_rate(new_lr)
        if opt.optimizer == 'rmsprop':
            # exponential decay of gamma
            if epoch != 1:
                trainer._optimizer.gamma1 *= .94
                logging.info('Setting rmsprop gamma to %f' % trainer._optimizer.gamma1)

        losses = []
        curr_losses_np = []

        for i in p_bar:
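            # With batch_k > 0, sample batch_size // batch_k classes and draw the batch directly from the dataset;
            # otherwise unpack the (index, batch) tuple produced by enumerate()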
            if opt.batch_k > 0:
                num_sampled_classes = batch_size // opt.batch_k
                batch = train_data.next_proxy_sample(sampled_classes=num_sampled_classes, chose_classes_randomly=True).data
            else:
                batch = i[1]
                i = i[0]

            data = mx.gluon.utils.split_and_load(batch[0], ctx_list=context, batch_axis=0, even_split=False)
            label = mx.gluon.utils.split_and_load(batch[1], ctx_list=context, batch_axis=0, even_split=False)

            with ag.record():
                for x, y in zip(data, label):
                    embs, proxies = net(x)
                    curr_loss = proxyloss(embs, proxies, y)
                    losses.append(curr_loss)
                mx.nd.waitall()

            curr_losses_np += [cl.asnumpy() for cl in losses]

            ag.backward(losses)

            trainer.step(batch[0].shape[0])

            #  Keep a moving average of the losses
            curr_loss = np.mean(np.maximum(np.concatenate(curr_losses_np), 0))
            curr_losses_np.clear()

            losses.clear()
            moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                           else (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss)
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)
        validation_results = validate(net, val_data, context, binarize=opt.binarize, nmi=opt.nmi, similarity=opt.similarity)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            filename = '%s.params' % opt.save_model_prefix
            logging.info('Saving %s.' % filename)
            net.save_parameters(filename)
            logging.info('New best validation: R@1: %f%s' % (best_results[0][1], (' NMI: %f' % best_results[-1][1]) if opt.nmi else ''))

    return best_results
Example #5
def train_dreml(opt):
    logging.info(opt)

    # Set random seed
    mx.random.seed(opt.seed)
    np.random.seed(opt.seed)

    # Setup computation context
    context = get_context(opt.gpus, logging)
    cpu_ctx = mx.cpu()

    # Adjust batch size to each compute context
    batch_size = opt.batch_size * len(context)

    if opt.model == 'inception-bn':
        scale_image_data = False
    elif opt.model in ['resnet50_v2', 'resnet18_v2']:
        scale_image_data = True
    else:
        raise RuntimeError('Unsupported model: %s' % opt.model)

    # Prepare datasets
    train_dataset, val_dataset = get_dataset(opt.dataset, opt.data_path, data_shape=opt.data_shape, use_crops=opt.use_crops,
                                             use_aug=True, with_proxy=True, scale_image_data=scale_image_data,
                                             resize_img=int(opt.data_shape * 1.1))

    # Create class mapping: each of the L ensemble members gets a random mapping from the original classes onto D meta-classes
    mapping = np.random.randint(0, opt.D, (opt.L, train_dataset.num_classes()))

    # Train embedding functions one by one
    trained_models = []
    best_results = []  # R@1, NMI
    for ens in tqdm(range(opt.L), desc='Training model in ensemble'):
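        # Relabel the training data with this ensemble member's random mapping onto D meta-classes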
        train_dataset.set_class_mapping(mapping[ens], opt.D)
        train_dataloader = mx.gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                                    num_workers=opt.num_workers, last_batch='rollover')

        if opt.model == 'inception-bn':
            feature_net, feature_params = get_feature_model(opt.model, ctx=context)
        elif opt.model == 'resnet50_v2':
            feature_net = mx.gluon.model_zoo.vision.resnet50_v2(pretrained=True, ctx=context).features
        elif opt.model == 'resnet18_v2':
            feature_net = mx.gluon.model_zoo.vision.resnet18_v2(pretrained=True, ctx=context).features
        else:
            raise RuntimeError('Unsupported model: %s' % opt.model)

        if opt.static_proxies:
            net = EmbeddingNet(feature_net, opt.D, normalize=False)
        else:
            net = ProxyNet(feature_net, opt.D, num_classes=opt.D)

        # Init loss function
        if opt.static_proxies:
            logging.info('Using static proxies')
            proxyloss = StaticProxyLoss(opt.D)
        elif opt.loss == 'nca':
            logging.info('Using NCA loss')
            proxyloss = ProxyNCALoss(opt.D, exclude_positives=True, label_smooth=opt.label_smooth,
                                     multiplier=opt.embedding_multiplier)
        elif opt.loss == 'triplet':
            logging.info('Using triplet loss')
            proxyloss = ProxyTripletLoss(opt.D)
        elif opt.loss == 'xentropy':
            logging.info('Using NCA loss without excluding positives')
            proxyloss = ProxyNCALoss(opt.D, exclude_positives=False, label_smooth=opt.label_smooth,
                                     multiplier=opt.embedding_multiplier)
        else:
            raise RuntimeError('Unknown loss function: %s' % opt.loss)

        # Init optimizer
        opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
        if opt.optimizer == 'sgd':
            opt_options['momentum'] = 0.9
        elif opt.optimizer == 'adam':
            opt_options['epsilon'] = opt.epsilon
        elif opt.optimizer == 'rmsprop':
            opt_options['gamma1'] = 0.9
            opt_options['epsilon'] = opt.epsilon

        # Calculate decay steps
        steps = parse_steps(opt.steps, opt.epochs, logger=logging)

        # reset networks
        if opt.model == 'inception-bn':
            net.base_net.collect_params().load(feature_params, ctx=context, ignore_extra=True)
        elif opt.model in ['resnet18_v2', 'resnet50_v2']:
            net.base_net = mx.gluon.model_zoo.vision.get_model(opt.model, pretrained=True, ctx=context).features
        else:
            raise NotImplementedError('Unknown model: %s' % opt.model)

        if opt.static_proxies:
            net.init(mx.init.Xavier(magnitude=0.2), ctx=context, init_basenet=False)
        elif opt.loss == 'triplet':
            net.encoder.initialize(mx.init.Xavier(magnitude=0.2), ctx=context, force_reinit=True)
            net.proxies.initialize(mx.init.Xavier(magnitude=0.2), ctx=context, force_reinit=True)
        else:
            net.init(TruncNorm(stdev=0.001), ctx=context, init_basenet=False)
        if not opt.disable_hybridize:
            net.hybridize()

        trainer = mx.gluon.Trainer(net.collect_params(), opt.optimizer,
                                   opt_options,
                                   kvstore=opt.kvstore)

        smoothing_constant = .01  # for tracking moving losses
        moving_loss = 0

        for epoch in range(1, opt.epochs + 1):
            p_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                         desc=('[Model %d/%d] Epoch %d' % (ens + 1, opt.L, epoch)))

            new_lr = get_lr(opt.lr, epoch, steps, opt.factor)
            logging.info('Setting LR to %f' % new_lr)
            trainer.set_learning_rate(new_lr)

            for i, batch in p_bar:
                data = mx.gluon.utils.split_and_load(batch[0], ctx_list=context, batch_axis=0, even_split=False)
                label = mx.gluon.utils.split_and_load(batch[1], ctx_list=context, batch_axis=0, even_split=False)
                negative_labels = mx.gluon.utils.split_and_load(batch[2], ctx_list=context, batch_axis=0,
                                                                even_split=False)

                with ag.record():
                    losses = []
                    for x, y, nl in zip(data, label, negative_labels):
                        if opt.static_proxies:
                            embs = net(x)
                            losses.append(proxyloss(embs, y))
                        else:
                            embs, positive_proxy, negative_proxies, proxies = net(x, y, nl)
                            if opt.loss in ['nca', 'xentropy']:
                                losses.append(proxyloss(embs, proxies, y, nl))
                            else:
                                losses.append(proxyloss(embs, positive_proxy, negative_proxies))
                for l in losses:
                    l.backward()

                trainer.step(data[0].shape[0])

                ##########################
                #  Keep a moving average of the losses
                ##########################
                curr_loss = mx.nd.mean(mx.nd.maximum(mx.nd.concatenate(losses), 0)).asscalar()
                moving_loss = (curr_loss if ((i == 0) and (epoch == 1))  # starting value
                               else (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss)
                p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

            logging.info('Moving loss: %.4f' % moving_loss)

        # move model to CPU
        mx.nd.waitall()
        net.collect_params().reset_ctx(cpu_ctx)
        trained_models.append(net)
        del train_dataloader

        # Run ensemble validation
        logging.info('Running validation with %d models in the ensemble' % len(trained_models))
        val_dataloader = mx.gluon.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                                  num_workers=opt.num_workers, last_batch='keep')

        validation_results = validate(val_dataloader, trained_models, context, opt.static_proxies)

        for name, val_acc in validation_results:
            logging.info('Validation: %s=%f' % (name, val_acc))

        if (len(best_results) == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            logging.info('New best validation: R@1: %f NMI: %f' % (best_results[0][1], best_results[-1][1]))
Example #6
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""

    if not opt.skip_pretrain_validation:
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd, 'clip_gradient': 10.}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    if opt.decrease_cnn_lr:
        logging.info('Setting embedding LR to %f' % (10.0 * opt.lr))
        for p, v in net.encoder.collect_params().items():
            v.lr_mult = 10.0

    trainer = mx.gluon.Trainer(net.collect_params(),
                               opt.optimizer,
                               opt_options,
                               kvstore=opt.kvstore)

    if opt.angular_lambda > 0:
        # Use N-pair and angular losses together; L2 regularization is 0 for the angular loss in this case
        L = AngluarLoss(alpha=np.deg2rad(opt.alpha),
                        l2_reg=0,
                        symmetric=opt.symmetric_loss)
        L2 = NPairsLoss(l2_reg=opt.l2reg_weight, symmetric=opt.symmetric_loss)
        if not opt.disable_hybridize:
            L2.hybridize()
    else:
        L = AngluarLoss(alpha=np.deg2rad(opt.alpha),
                        l2_reg=opt.l2reg_weight,
                        symmetric=opt.symmetric_loss)
    if not opt.disable_hybridize:
        L.hybridize()

    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        prev_loss, cumulative_loss = 0.0, 0.0
        # Learning rate schedule.
        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))
        logging.info('Epoch %d learning rate=%f', epoch, trainer.learning_rate)

        p_bar = tqdm(train_dataloader,
                     desc=('[Run %d/%d] Epoch %d' %
                           (run_id, opt.number_of_runs, epoch)))
        for batch in p_bar:
            anchors_batch = batch[0][0]  # <N x I>
            positives_batch = batch[1][0]  # <N x I>

            anchors = mx.gluon.utils.split_and_load(anchors_batch,
                                                    ctx_list=context,
                                                    batch_axis=0)
            positives = mx.gluon.utils.split_and_load(positives_batch,
                                                      ctx_list=context,
                                                      batch_axis=0)
            labels_batch = mx.gluon.utils.split_and_load(batch[2][0],
                                                         ctx_list=context,
                                                         batch_axis=0)
            anchor_embs = []
            positives_embs = []

            with ag.record():
                for a, p in zip(anchors, positives):
                    a_emb = net(a)
                    p_emb = net(p)
                    anchor_embs.append(a_emb)
                    positives_embs.append(p_emb)
                anchors = mx.nd.concat(*anchor_embs, dim=0)
                positives = mx.nd.concat(*positives_embs, dim=0)

                if opt.angular_lambda > 0:
                    angular_loss = L(anchors, positives, labels_batch[0])
                    npairs_loss = L2(anchors, positives, labels_batch[0])
                    loss = npairs_loss + (opt.angular_lambda * angular_loss)
                else:
                    loss = L(anchors, positives, labels_batch[0])

            loss.backward()
            cumulative_loss += mx.nd.mean(loss).asscalar()
            trainer.step(opt.batch_size)

            p_bar.set_postfix({'loss': cumulative_loss - prev_loss})
            prev_loss = cumulative_loss

        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' %
                         (epoch, name, val_acc))

        if (len(best_results)
                == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f' %
                         (best_results[0][1], best_results[-1][1]))

    return best_results
Example #7
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""
    if not opt.skip_pretrain_validation:
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0,
                                      nmi=opt.nmi)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {
        'learning_rate': opt.lr,
        'wd': opt.wd,
    }
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7
    trainer = mx.gluon.Trainer(net.collect_params(),
                               opt.optimizer,
                               opt_options,
                               kvstore=opt.kvstore)

    L = RankedListLoss(margin=opt.margin,
                       alpha=opt.alpha,
                       temperature=opt.temperature)
    if not opt.disable_hybridize:
        L.hybridize()

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader),
                     desc='[Run %d/%d] Epoch %d' %
                     (run_id, opt.number_of_runs, epoch),
                     total=len(train_dataloader))
        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))

        for i, (data, labels) in p_bar:
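            # Single-device path: move the whole batch to the first context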
            data = data[0].as_in_context(context[0])
            labels = labels[0].astype('int32').as_in_context(context[0])

            with ag.record():
                losses = []
                embs = net(data)
                losses.append(L(embs, labels))
            for l in losses:
                l.backward()

            trainer.step(1)

            #  Keep a moving average of the losses
            curr_loss = mx.nd.mean(mx.nd.concatenate(losses)).asscalar()
            moving_loss = (
                curr_loss if ((i == 0) and (epoch == 1))  # starting value
                else (1 - smoothing_constant) * moving_loss +
                smoothing_constant * curr_loss)  # add current
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0,
                                      nmi=opt.nmi)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' %
                         (epoch, name, val_acc))

        if (len(best_results)
                == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f' %
                         (best_results[0][1], best_results[-1][1]))

    return best_results
Example #8
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""

    if not opt.skip_pretrain_validation:
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    elif opt.optimizer == 'adam':
        opt_options['epsilon'] = opt.epsilon
    elif opt.optimizer == 'rmsprop':
        opt_options['gamma1'] = 0.94
        opt_options['epsilon'] = opt.epsilon

    if opt.decrease_cnn_lr:
        logging.info('Setting embedding LR to %f' % (10.0 * opt.lr))
        for p, v in net.encoder.collect_params().items():
            v.lr_mult = 10.0

    trainer = mx.gluon.Trainer(net.collect_params(),
                               opt.optimizer,
                               opt_options,
                               kvstore=opt.kvstore)

    L = ClusterLoss(
        num_classes=train_dataloader._dataset.num_classes())  # Not hybridizable

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader),
                     total=len(train_dataloader),
                     desc=('[Run %d/%d] Epoch %d' %
                           (run_id, opt.number_of_runs, epoch)))
        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))

        if opt.optimizer == 'rmsprop':
            # exponential decay of gamma1
            if epoch != 1:
                trainer._optimizer.gamma1 *= .94

        for i, (data, labels) in p_bar:
            if opt.iteration_per_epoch > 0:
                data = data[0]
                labels = labels[0]
            labels = labels.astype('int32', copy=False)
            unique_labels = unique(mx.nd, labels).astype('float32')

            # extract label stats
            num_classes_batch = []
            if len(context) == 1:
                num_classes_batch.append(
                    mx.nd.array([unique_labels.size], dtype='int32'))
            else:
                slices = mx.gluon.utils.split_data(labels,
                                                   len(context),
                                                   batch_axis=0,
                                                   even_split=False)
                for s in slices:
                    num_classes_batch.append(
                        mx.nd.array([np.unique(s.asnumpy()).size],
                                    dtype='int32'))

            data = mx.gluon.utils.split_and_load(data,
                                                 ctx_list=context,
                                                 batch_axis=0,
                                                 even_split=False)
            label = mx.gluon.utils.split_and_load(labels,
                                                  ctx_list=context,
                                                  batch_axis=0,
                                                  even_split=False)

            unique_labels = mx.gluon.utils.split_and_load(unique_labels,
                                                          ctx_list=context,
                                                          batch_axis=0,
                                                          even_split=False)

            with ag.record():
                losses = []
                for x, y, uy, nc in zip(data, label, unique_labels,
                                        num_classes_batch):
                    embs = net(x)
                    losses.append(
                        L(
                            embs, y.astype('float32', copy=False), uy,
                            mx.nd.arange(start=0,
                                         stop=x.shape[0],
                                         ctx=y.context)))
            for l in losses:
                l.backward()

            trainer.step(data[0].shape[0])

            #  Keep a moving average of the losses
            curr_loss = mx.nd.mean(mx.nd.concatenate(losses)).asscalar()
            moving_loss = (
                curr_loss if ((i == 0) and (epoch == 1))  # starting value
                else (1 - smoothing_constant) * moving_loss +
                smoothing_constant * curr_loss)
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        logging.info('Moving loss: %.4f' % moving_loss)
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' %
                         (epoch, name, val_acc))

        if (len(best_results)
                == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f' %
                         (best_results[0][1], best_results[-1][1]))

    return best_results
Example #9
def train(net, opt, train_dataloader, val_dataloader, context, run_id):
    """Training function."""

    if not opt.skip_pretrain_validation:
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    if opt.decrease_cnn_lr:
        logging.info('Setting embedding LR to %f' % (10.0 * opt.lr))
        for p, v in net.encoder.collect_params().items():
            v.lr_mult = 10.0

    trainer = mx.gluon.Trainer(net.collect_params(),
                               opt.optimizer,
                               opt_options,
                               kvstore=opt.kvstore)

    L = NPairsLoss(l2_reg=opt.l2reg_weight, symmetric=opt.symmetric_loss)
    if not opt.disable_hybridize:
        L.hybridize()

    smoothing_constant = .01  # for tracking moving losses
    moving_loss = 0
    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        p_bar = tqdm(enumerate(train_dataloader),
                     total=len(train_dataloader),
                     desc=('[Run %d/%d] Epoch %d' %
                           (run_id, opt.number_of_runs, epoch)))
        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))

        for i, batch in p_bar:
            anchors_batch = batch[0][0]  # <N x I>
            positives_batch = batch[1][0]  # <N x I>

            anchors = mx.gluon.utils.split_and_load(anchors_batch,
                                                    ctx_list=context,
                                                    batch_axis=0)
            positives = mx.gluon.utils.split_and_load(positives_batch,
                                                      ctx_list=context,
                                                      batch_axis=0)
            labels_batch = mx.gluon.utils.split_and_load(batch[2][0],
                                                         ctx_list=context,
                                                         batch_axis=0)
            anchor_embs = []
            positives_embs = []

            with ag.record():
                for a, p in zip(anchors, positives):
                    a_emb = net(a)
                    p_emb = net(p)
                    anchor_embs.append(a_emb)
                    positives_embs.append(p_emb)
                anchors = mx.nd.concat(*anchor_embs, dim=0)
                positives = mx.nd.concat(*positives_embs, dim=0)

                loss = L(anchors, positives, labels_batch[0])

            loss.backward()
            trainer.step(opt.batch_size / 2)

            curr_loss = mx.nd.mean(loss).asscalar()
            moving_loss = (
                curr_loss if ((i == 0) and (epoch == 1))  # starting value
                else (1 - smoothing_constant) * moving_loss +
                smoothing_constant * curr_loss)  # add current
            p_bar.set_postfix_str('Moving loss: %.4f' % moving_loss)

        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' %
                         (epoch, name, val_acc))

        if (len(best_results)
                == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f' %
                         (best_results[0][1], best_results[-1][1]))

    return best_results
Example #10
def train(net, beta, opt, train_dataloader, val_dataloader, batch_size,
          context, run_id):
    """Training function."""

    if not opt.skip_pretrain_validation:
        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('Pre-train validation: %s=%f' % (name, val_acc))

    steps = parse_steps(opt.steps, opt.epochs, logging)

    opt_options = {'learning_rate': opt.lr, 'wd': opt.wd}
    if opt.optimizer == 'sgd':
        opt_options['momentum'] = 0.9
    if opt.optimizer == 'adam':
        opt_options['epsilon'] = 1e-7

    trainer = gluon.Trainer(net.collect_params(),
                            opt.optimizer,
                            opt_options,
                            kvstore=opt.kvstore)

    # beta is either a fixed scalar margin or a learnable (class-specific) parameter block
    train_beta = not isinstance(beta, float)

    if train_beta:
        # Jointly train class-specific beta
        beta.initialize(mx.init.Constant(opt.beta), ctx=context)
        trainer_beta = gluon.Trainer(beta.collect_params(),
                                     'sgd', {
                                         'learning_rate': opt.lr_beta,
                                         'momentum': 0.9
                                     },
                                     kvstore=opt.kvstore)
    loss = MarginLoss(batch_size,
                      opt.batch_k,
                      beta,
                      margin=opt.margin,
                      nu=opt.nu,
                      train_beta=train_beta)
    if not opt.disable_hybridize:
        loss.hybridize()

    best_results = []  # R@1, NMI

    for epoch in range(1, opt.epochs + 1):
        prev_loss, cumulative_loss = 0.0, 0.0

        # Learning rate schedule.
        trainer.set_learning_rate(get_lr(opt.lr, epoch, steps, opt.factor))
        logging.info('Epoch %d learning rate=%f', epoch, trainer.learning_rate)
        if train_beta:
            trainer_beta.set_learning_rate(
                get_lr(opt.lr_beta, epoch, steps, opt.factor))
            logging.info('Epoch %d beta learning rate=%f', epoch,
                         trainer_beta.learning_rate)

        p_bar = tqdm(train_dataloader,
                     desc='[Run %d/%d] Epoch %d' %
                     (run_id, opt.number_of_runs, epoch),
                     total=opt.iteration_per_epoch)
        for batch in p_bar:
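            # Each batch field is wrapped in an extra leading dim; [0] selects the actual <N x ...> array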
            data = gluon.utils.split_and_load(batch[0][0],
                                              ctx_list=context,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1][0].astype('float32'),
                                               ctx_list=context,
                                               batch_axis=0)

            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    embeddings = net(x)
                    L = loss(embeddings, y)

                    Ls.append(L)
                    cumulative_loss += nd.mean(L).asscalar()

                for L in Ls:
                    L.backward()

            trainer.step(batch[0].shape[1])
            if opt.lr_beta > 0.0:
                trainer_beta.step(batch[0].shape[1])

            p_bar.set_postfix({'loss': cumulative_loss - prev_loss})
            prev_loss = cumulative_loss

        logging.info('[Epoch %d] training loss=%f' % (epoch, cumulative_loss))

        validation_results = validate(net,
                                      val_dataloader,
                                      context,
                                      use_threads=opt.num_workers > 0)
        for name, val_acc in validation_results:
            logging.info('[Epoch %d] validation: %s=%f' %
                         (epoch, name, val_acc))

        if (len(best_results)
                == 0) or (validation_results[0][1] > best_results[0][1]):
            best_results = validation_results
            if opt.save_model_prefix.lower() != 'none':
                filename = '%s.params' % opt.save_model_prefix
                logging.info('Saving %s.' % filename)
                net.save_parameters(filename)
            logging.info('New best validation: R@1: %f NMI: %f' %
                         (best_results[0][1], best_results[-1][1]))

    return best_results