Example #1
    def __init__(self, use_float16=False):
        self._transform_test = transforms.Compose([transforms.ToTensor()])

        self._transform_train = transforms.Compose([
            transforms.RandomBrightness(0.3),
            transforms.RandomContrast(0.3),
            transforms.RandomSaturation(0.3),
            transforms.RandomFlipLeftRight(),
            transforms.ToTensor()
        ])
        self.use_float16 = use_float16
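A minimal usage sketch for the pipelines above, assuming MXNet is installed and rebuilding the training transform standalone (the class that owns this __init__ is not part of the snippet):

from mxnet import nd
from mxnet.gluon.data.vision import transforms

# Same augmentation pipeline as self._transform_train above.
transform_train = transforms.Compose([
    transforms.RandomBrightness(0.3),
    transforms.RandomContrast(0.3),
    transforms.RandomSaturation(0.3),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor()
])

# A dummy HWC uint8 image (the 112x112 size is arbitrary); ToTensor converts it
# to a CHW float32 tensor scaled to [0, 1].
img = nd.random.uniform(0, 255, shape=(112, 112, 3)).astype('uint8')
out = transform_train(img)
print(out.shape, out.dtype)  # (3, 112, 112) float32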
Example #2
def test_transformer():
    import mxnet as mx
    from mxnet.gluon.data.vision import transforms

    transform = transforms.Compose([
        transforms.Resize(300),
        transforms.CenterCrop(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomFlipLeftRight(),
        transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.RandomBrightness(0.1),
        transforms.RandomContrast(0.1),
        transforms.RandomSaturation(0.1),
        transforms.RandomHue(0.1),
        transforms.RandomLighting(0.1),
        transforms.ToTensor(),
        transforms.Normalize([0, 0, 0], [1, 1, 1])])

    transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
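The test chains every image transform on a dummy 245x480x3 uint8 image; because MXNet executes operators asynchronously, the trailing wait_to_read() blocks until the result is computed, so any error raised inside the pipeline surfaces within the test itself.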
Example #3
# Imports this excerpt relies on (not shown in the original snippet); logger,
# filehandler, streamhandler and opt are assumed to be set up earlier in the
# full script.
import os
import time
import logging

import numpy as np
import mxnet as mx
from mxnet import gluon
from mxnet import autograd as ag
from mxnet.gluon.data.vision import transforms
from gluoncv.data import transforms as gcv_transforms
from gluoncv.model_zoo import get_model
from gluoncv.utils import makedirs
from mxboard import SummaryWriter
from tqdm import tqdm

logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)

logger.info(opt)
if opt.dataset == 'emore' and opt.batch_size < 512:
    logger.info("Warning: If you train a model on emore with batch size < 512 may lead to not converge."
                "You may try a smaller dataset.")

transform_test = transforms.Compose([
    transforms.ToTensor()
])

_transform_train = transforms.Compose([
    transforms.RandomBrightness(0.3),
    transforms.RandomContrast(0.3),
    transforms.RandomSaturation(0.3),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor()
])


def transform_train(data, label):
    im = _transform_train(data)
    return im, label
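Since transform_train takes a (data, label) pair, it pairs with Dataset.transform rather than transform_first; a small sketch, using CIFAR10 as a stand-in for the face dataset the original script loads:

from mxnet import gluon

# Any Gluon vision dataset works for illustration; the real training data for
# this script is a face-recognition dataset that is not part of the excerpt.
dataset = gluon.data.vision.CIFAR10(train=True).transform(transform_train)
loader = gluon.data.DataLoader(dataset, batch_size=64, shuffle=True)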


def inf_train_gen(loader):
    while True:
        for batch in loader:
            yield batch
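inf_train_gen restarts the loader whenever it is exhausted, which suits iteration-based rather than epoch-based training loops; a sketch of how it might be consumed, reusing the hypothetical loader from the previous sketch:

gen = inf_train_gen(loader)
for step in range(1000):
    data, label = next(gen)
    # ... forward / backward / parameter update on this batch ...
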
def main():
    opt = parse_args()
    batch_size = opt.batch_size
    classes = 10

    log_dir = os.path.join(opt.save_dir, "logs")
    model_dir = os.path.join(opt.save_dir, "params")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Init dataloader
    jitter_param = 0.4
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.RandomBrightness(jitter_param),
        transforms.RandomColorJitter(jitter_param),
        transforms.RandomContrast(jitter_param),
        transforms.RandomSaturation(jitter_param),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])
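    # The mean/std triples above are the standard per-channel statistics of the
    # CIFAR-10 training set, so images are normalized to roughly zero mean and
    # unit variance before entering the network.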

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
        batch_size=batch_size,
        shuffle=True,
        last_batch='discard',
        num_workers=opt.num_workers)
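    # transform_first applies the transform to the image only, leaving the label
    # untouched; last_batch='discard' drops a trailing incomplete batch so every
    # step sees exactly batch_size samples.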

    val_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
        batch_size=batch_size,
        shuffle=False,
        num_workers=opt.num_workers)

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]
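    # The trailing np.inf is a sentinel so the lr_decay_count index used below
    # can never run past the end of the decay schedule.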

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)

    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    def test(ctx, val_loader):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_loader):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(train_data, val_data, epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        net.hybridize()
        net.initialize(mx.init.Xavier(), ctx=ctx)
        net.forward(mx.nd.ones((1, 3, 30, 30), ctx=ctx[0]))
        with SummaryWriter(logdir=log_dir, verbose=False) as sw:
            sw.add_graph(net)
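        # The dummy forward pass above traces the hybridized network once so its
        # cached symbolic graph is available for mxboard's add_graph export.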

        trainer = gluon.Trainer(net.collect_params(), optimizer, {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        })
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0
        global_step = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            tbar = tqdm(train_data)

            for i, batch in enumerate(tbar):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
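                # step(batch_size) rescales the accumulated gradients by
                # 1/batch_size (the combined batch across all devices) before
                # the optimizer update.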
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1
                global_step += len(loss)

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('{}/{}-{}-{:04.3f}-best.params'.format(
                    model_dir, model_name, epoch, best_val_score))

            with SummaryWriter(logdir=log_dir, verbose=False) as sw:
                sw.add_scalar(tag="TrainLos",
                              value=train_loss,
                              global_step=global_step)
                sw.add_scalar(tag="TrainAcc",
                              value=acc,
                              global_step=global_step)
                sw.add_scalar(tag="ValAcc",
                              value=val_acc,
                              global_step=global_step)
                sw.add_graph(net)

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('{}/{}-{}.params'.format(
                    save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('{}/{}-{}.params'.format(
                save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(train_data, val_data, opt.num_epochs, context)
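The excerpt ends inside main(); presumably the full script invokes it through a standard entry-point guard (not shown here), along the lines of:

if __name__ == '__main__':
    main()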