def __init__(self, use_float16=False):
    self._transform_test = transforms.Compose([
        transforms.ToTensor()
    ])
    self._transform_train = transforms.Compose([
        transforms.RandomBrightness(0.3),
        transforms.RandomContrast(0.3),
        transforms.RandomSaturation(0.3),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor()
    ])
    self.use_float16 = use_float16
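# Usage sketch (an assumption; the enclosing class is not shown in this
# excerpt). Given an instance `aug` of this class, the two pipelines are
# typically applied per-sample with Dataset.transform_first, which transforms
# the image and leaves the label untouched; CIFAR10 is an illustrative
# stand-in dataset here.
#
#   train_set = gluon.data.vision.CIFAR10(train=True).transform_first(aug._transform_train)
#   val_set = gluon.data.vision.CIFAR10(train=False).transform_first(aug._transform_test)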
def test_transformer():
    from mxnet.gluon.data.vision import transforms

    transform = transforms.Compose([
        transforms.Resize(300),
        transforms.CenterCrop(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomFlipLeftRight(),
        transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.RandomBrightness(0.1),
        transforms.RandomContrast(0.1),
        transforms.RandomSaturation(0.1),
        transforms.RandomHue(0.1),
        transforms.RandomLighting(0.1),
        transforms.ToTensor(),
        transforms.Normalize([0, 0, 0], [1, 1, 1])
    ])
    transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)
logger.info(opt)

if opt.dataset == 'emore' and opt.batch_size < 512:
    logger.info("Warning: training on emore with a batch size smaller than 512 "
                "may fail to converge. You may try a smaller dataset.")

transform_test = transforms.Compose([
    transforms.ToTensor()
])
_transform_train = transforms.Compose([
    transforms.RandomBrightness(0.3),
    transforms.RandomContrast(0.3),
    transforms.RandomSaturation(0.3),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor()
])

def transform_train(data, label):
    # Augment the image only; pass the label through unchanged.
    im = _transform_train(data)
    return im, label

def inf_train_gen(loader):
    # Yield batches forever, restarting the loader each time it is exhausted.
    while True:
        for batch in loader:
            yield batch
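# Usage sketch for inf_train_gen (an assumption, not part of the original
# script): wrap a Gluon DataLoader so an iteration-based training loop can
# call next() without tracking epoch boundaries. `train_loader` and
# `num_iterations` are hypothetical names used for illustration.
#
#   train_gen = inf_train_gen(train_loader)
#   for it in range(num_iterations):
#       data, label = next(train_gen)
#       ...  # forward/backward step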
def main():
    opt = parse_args()
    batch_size = opt.batch_size
    classes = 10

    log_dir = os.path.join(opt.save_dir, "logs")
    model_dir = os.path.join(opt.save_dir, "params")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Init dataloader
    jitter_param = 0.4
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.RandomBrightness(jitter_param),
        transforms.RandomColorJitter(jitter_param),
        transforms.RandomContrast(jitter_param),
        transforms.RandomSaturation(jitter_param),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])
    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
        batch_size=batch_size, shuffle=True, last_batch='discard',
        num_workers=opt.num_workers)
    val_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
        batch_size=batch_size, shuffle=False, num_workers=opt.num_workers)

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    def test(ctx, val_loader):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_loader):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(train_data, val_data, epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.hybridize()
        net.initialize(mx.init.Xavier(), ctx=ctx)
        # Run one dummy forward pass so the hybridized graph is built
        # before it is dumped to TensorBoard.
        net.forward(mx.nd.ones((1, 3, 30, 30), ctx=ctx[0]))
        with SummaryWriter(logdir=log_dir, verbose=False) as sw:
            sw.add_graph(net)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd,
                                 'momentum': opt.momentum})
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        lr_decay_count = 0
        best_val_score = 0
        global_step = 0
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            # Step the learning-rate schedule.
            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            tbar = tqdm(train_data)
            for i, batch in enumerate(tbar):
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])
                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1
                global_step += len(loss)
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('{}/{}-{}-{:04.3f}-best.params'.format(
                    model_dir, model_name, epoch, best_val_score))
            with SummaryWriter(logdir=log_dir, verbose=False) as sw:
                sw.add_scalar(tag="TrainLoss", value=train_loss, global_step=global_step)
                sw.add_scalar(tag="TrainAcc", value=acc, global_step=global_step)
                sw.add_scalar(tag="ValAcc", value=val_acc, global_step=global_step)
                sw.add_graph(net)
            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))
            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('{}/{}-{}.params'.format(
                    save_dir, model_name, epoch))
        if save_period and save_dir:
            net.save_parameters('{}/{}-{}.params'.format(
                save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(train_data, val_data, opt.num_epochs, context)
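# Standard script entry point (an assumption: the excerpt ends at main(), but
# training scripts in this style are normally invoked directly from the
# command line).
if __name__ == '__main__':
    main()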