def wrapper(*args, **kwargs):
    # Run the wrapped test inside a throwaway BytePS deployment: launch the
    # scheduler and server in background daemon threads, run `func`, then
    # shut BytePS down and wait for the helper threads to exit.
    # NOTE(review): relies on `cls` and `func` from the enclosing decorator
    # scope, which is not visible in this chunk.
    def run(env):
        # Launch "bpslaunch" with the given process environment.
        # NOTE(review): shell=True combined with a list argv means only the
        # first element is used as the shell command on POSIX — confirm this
        # is intentional before changing it.
        subprocess.check_call(args=["bpslaunch"], shell=True,
                              stdout=sys.stdout, stderr=sys.stderr, env=env)
    print("bps init")
    scheduler = threading.Thread(target=run, args=(cls.SCHEDULER_ENV, ))
    server = threading.Thread(target=run, args=(cls.SERVER_ENV, ))
    # Daemon threads so a hung scheduler/server cannot block interpreter exit.
    scheduler.daemon = True
    server.daemon = True
    scheduler.start()
    server.start()
    bps.init()
    func(*args, **kwargs)
    bps.shutdown()
    scheduler.join()
    server.join()
    print("bps shutdown")
    # Grace period so ports/processes are released before the next test.
    time.sleep(2)
def test_byteps_broadcast(self):
    """Test that the broadcast correctly broadcasts 1D, 2D, 3D tensors.

    For every (dtype, dimensionality, root rank) combination, checks that
    (a) the source tensor is left untouched on non-root ranks and
    (b) the broadcast result equals the root's tensor on every rank.
    """
    bps.init()
    rank = bps.rank()
    size = bps.size()

    # This test does not apply if there is only one worker.
    if size == 1:
        return

    dtypes = ['int32', 'int64', 'float32', 'float64']
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 0
    # BUGFIX: (17) is just the int 17, not a 1-tuple; use (17,) so every
    # entry is a genuine shape tuple.
    shapes = [(), (17,), (17, 17), (17, 17, 17)]
    root_ranks = list(range(size))
    for dtype, dim, root_rank in itertools.product(dtypes, dims, root_ranks):
        # Each rank holds a tensor filled with its own rank id; the root's
        # copy is what every rank should end up with after the broadcast.
        tensor = mx.nd.ones(shapes[dim], ctx=ctx) * rank
        root_tensor = mx.nd.ones(shapes[dim], ctx=ctx) * root_rank
        tensor = tensor.astype(dtype)
        root_tensor = root_tensor.astype(dtype)

        broadcast_tensor = bps.broadcast(tensor, root_rank=root_rank,
                                         name=str(count))
        # BUGFIX: `count` was never incremented, so every broadcast reused
        # the name "0". BytePS keys tensors by name, and re-using one name
        # for tensors of different dtype/shape is invalid.
        count += 1
        if rank != root_rank:
            # Broadcast must not modify the source tensor in place.
            if same(tensor.asnumpy(), root_tensor.asnumpy()):
                print("broadcast", count, dtype, dim,
                      mx.nd.max(tensor == root_tensor))
                print("tensor", bps.rank(), tensor)
                print("root_tensor", bps.rank(), root_tensor)
                print("comparison", bps.rank(), tensor == root_tensor)
            assert not same(tensor.asnumpy(), root_tensor.asnumpy()), \
                'bps.broadcast modifies source tensor'
        if not same(broadcast_tensor.asnumpy(), root_tensor.asnumpy()):
            print("broadcast", count, dtype, dim)
            print("broadcast_tensor", bps.rank(), broadcast_tensor)
            print("root_tensor", bps.rank(), root_tensor)
            print("comparison", bps.rank(), broadcast_tensor == root_tensor)
        assert same(broadcast_tensor.asnumpy(), root_tensor.asnumpy()), \
            'bps.broadcast produces incorrect broadcasted tensor'
def init_comm(backend):
    """Init communication backend"""
    # backend specific implementation
    kv_store = None
    if backend == 'horovod':
        try:
            import horovod.mxnet as hvd
        except ImportError:
            logging.info('horovod must be installed.')
            exit()
        hvd.init()
        world_size = hvd.size()
        worker_rank = hvd.rank()
        device_rank = hvd.local_rank()
        contexts = [mx.gpu(device_rank)]
    elif backend == 'byteps':
        try:
            import byteps.mxnet as bps
        except ImportError:
            logging.info('BytePS must be installed.')
            exit()
        bps.init()
        world_size = bps.size()
        worker_rank = bps.rank()
        device_rank = bps.local_rank()
        contexts = [mx.gpu(device_rank)]
    else:
        # Anything else is treated as an MXNet kvstore type.
        kv_store = mx.kv.create(backend)
        world_size = kv_store.num_workers
        worker_rank = kv_store.rank
        device_rank = 0
        if args.gpus is None or args.gpus == '':
            contexts = [mx.cpu()]
        else:
            contexts = [mx.gpu(int(x)) for x in args.gpus.split(',')]
    # A worker is the "master" node when its global rank equals its
    # local rank (i.e. the first worker on each rank-0 machine).
    is_master = worker_rank == device_rank
    return kv_store, world_size, worker_rank, device_rank, is_master, contexts
def test_byteps_push_pull_inplace(self):
    """Test that the byteps_push_pull correctly sums 1D, 2D, 3D tensors."""
    bps.init()
    size = bps.size()
    dtypes = self.filter_supported_types(['int32', 'int64', 'float32', 'float64'])
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 0
    # NOTE(review): (17) is the int 17, not a 1-tuple; mx.nd accepts an int
    # shape, so the entries still describe 1D/2D/3D tensors.
    shapes = [(), (17), (17, 17), (17, 17, 17)]
    for dtype, dim in itertools.product(dtypes, dims):
        # Fixed seed on every worker, so each rank draws the same tensor
        # and the in-place push_pull sum should equal tensor * size.
        mx.random.seed(1234, ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=ctx)
        tensor = tensor.astype(dtype)
        multiplied = tensor * size
        bps.byteps_declare_tensor(tensor, "tensor_" + str(count))
        bps.byteps_push_pull(tensor, name= "tensor_" + str(count))
        max_difference = mx.nd.max(mx.nd.subtract(tensor, multiplied))
        count += 1

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in ['int32', 'int64']:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            break
        if max_difference > threshold:
            print("self", count, dtype, dim, max_difference, threshold)
            print("tensor", bps.rank(), tensor)
            print("multiplied", bps.rank(), multiplied)
        # NOTE(review): the assertion message below is truncated in this
        # chunk — the trailing backslash continues a string literal outside
        # the visible span.
        assert max_difference <= threshold, 'bps.byteps_push_pull produces \
def test_byteps_push_pull(self):
    """Test that the byteps_push_pull correctly sums 1D, 2D, 3D tensors."""
    bps.init()
    world_size = bps.size()
    supported_dtypes = self.filter_supported_types(['float32'])
    ndims = [1]
    device = self._current_context()
    shapes = [(), (17)]
    idx = 0
    for dtype, dim in itertools.product(supported_dtypes, ndims):
        # MXNet uses gpu_id as part of the seed, so to get identical seeds
        # we must set a context.
        mx.random.seed(10 + 10 * bps.rank(), ctx=device)
        data = mx.nd.random.uniform(-100, 100, shape=shapes[dim], ctx=device)
        data = data.astype(dtype)
        print("tensor before push_pull:", data)
        key = "tensor_" + str(idx)
        bps.byteps_declare_tensor(data, key)
        bps.byteps_push_pull(data, name=key)
        data.wait_to_read()
        print("tensor after push_pull:", data)
    bps.shutdown()
def main():
    """Distributed CIFAR-100 training entry point (BytePS backend).

    Parses CLI options, sets up per-worker logging, builds the model,
    data transforms and LR schedule, then runs `train` for the requested
    number of epochs.
    """
    opt = parse_args()

    bps.init()

    # Name the log file after the world size and the local GPU model.
    gpu_name = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
    gpu_name = gpu_name.decode('utf8').split('\n')[-2]
    gpu_name = '-'.join(gpu_name.split())
    filename = "cifar100-%d-%s-%s.log" % (bps.size(), gpu_name,
                                          opt.logging_file)
    filehandler = logging.FileHandler(filename)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 100

    num_gpus = opt.num_gpus
    # batch_size *= max(1, num_gpus)
    # Pin one device per worker, indexed by BytePS local rank.
    context = mx.gpu(bps.local_rank()) if num_gpus > 0 else mx.cpu(
        bps.local_rank())
    num_workers = opt.num_workers
    nworker = bps.size()
    rank = bps.rank()

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]

    # BUGFIX: the num_batches / lr_scheduler computation below appeared
    # twice verbatim; the redundant duplicate has been removed.
    # 50000 = CIFAR-100 training-set size; each worker sees 1/nworker of it.
    num_batches = 50000 // (opt.batch_size * nworker)
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=opt.warmup_lr,
                    target_lr=opt.lr * nworker / bps.local_size(),
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler('step',
                    base_lr=opt.lr * nworker / bps.local_size(),
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)

    # Gradient compression is only wired up for plain SGD.
    if opt.compressor:
        optimizer = 'sgd'
    else:
        optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    # from https://github.com/weiaicunzai/pytorch-cifar/blob/master/conf/global_settings.py
    CIFAR100_TRAIN_MEAN = [
        0.5070751592371323, 0.48654887331495095, 0.4409178433670343
    ]
    CIFAR100_TRAIN_STD = [
        0.2673342858792401, 0.2564384629170883, 0.27615047132568404
    ]

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    def test(ctx, val_data):
        # Evaluate classification accuracy over the validation shard.
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        # Full training loop: data loading, trainer setup, per-epoch
        # optimization, metric push_pull across workers, checkpointing.
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        # Each worker trains on its own shard of the dataset.
        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).shard(nworker, rank).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).shard(nworker, rank).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        params = net.collect_params()

        compression_params = {
            "compressor": opt.compressor,
            "ef": opt.ef,
            "momentum": opt.compress_momentum,
            "scaling": opt.onebit_scaling,
            "k": opt.k,
            "fp16": opt.fp16_pushpull
        }

        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

        trainer = bps.DistributedTrainer(params,
                                         optimizer,
                                         optimizer_params,
                                         compression_params=compression_params)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        best_val_score = 0
        bps.byteps_declare_tensor("acc")
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y)
                            for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                train_loss += sum([l.sum().asscalar() for l in loss])
                train_metric.update(label, output)
                name, train_acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, train_acc = train_metric.get()
            throughput = int(batch_size * nworker * i / (time.time() - tic))

            logger.info(
                '[Epoch %d] speed: %d samples/sec\ttime cost: %f lr=%f' %
                (epoch, throughput, time.time() - tic,
                 trainer.learning_rate))

            name, val_acc = test(ctx, val_data)

            # Average train/val accuracy across all workers via push_pull.
            acc = mx.nd.array([train_acc, val_acc], ctx=ctx[0])
            bps.byteps_push_pull(acc, name="acc", is_average=False)
            acc /= bps.size()
            train_acc, val_acc = acc[0].asscalar(), acc[1].asscalar()
            if bps.rank() == 0:
                logger.info('[Epoch %d] training: %s=%f' %
                            (epoch, name, train_acc))
                logger.info('[Epoch %d] validation: %s=%f' %
                            (epoch, name, val_acc))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar100-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar100-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
def main():
    """Distributed ImageNet training entry point (BytePS backend).

    Parses CLI options, sets up per-worker logging, builds the model
    (optionally with a distillation teacher), data pipelines and LR
    schedule, then runs `train`.
    """
    opt = parse_args()

    bps.init()

    # Name the log file after the world size and the local GPU model.
    gpu_name = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
    gpu_name = gpu_name.decode('utf8').split('\n')[-2]
    gpu_name = '-'.join(gpu_name.split())
    filename = "imagenet-%d-%s-%s.log" % (bps.size(), gpu_name,
                                          opt.logging_file)
    filehandler = logging.FileHandler(filename)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    # batch_size *= max(1, num_gpus)
    # Pin one device per worker, indexed by BytePS local rank.
    context = mx.gpu(bps.local_rank()) if num_gpus > 0 else mx.cpu(
        bps.local_rank())
    num_workers = opt.num_workers
    nworker = bps.size()
    rank = bps.rank()

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    # Decay epochs are counted after the warmup phase ends.
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // (batch_size * nworker)

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=opt.warmup_lr,
                    target_lr=opt.lr * nworker / bps.local_size(),
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr * nworker / bps.local_size(),
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model

    kwargs = {
        'ctx': context,
        'pretrained': opt.use_pretrained,
        'classes': classes
    }
    if opt.use_gn:
        from gluoncv.nn import GroupNorm
        kwargs['norm_layer'] = GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    # Gradient compression is only wired up for plain SGD.
    if opt.compressor:
        optimizer = 'sgd'
    else:
        optimizer = 'nag'

    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)
    # BUGFIX: was `opt.resume_params is not ''` — identity comparison with a
    # string literal has undefined results (SyntaxWarning on CPython 3.8+);
    # use equality instead.
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name,
                            pretrained=True,
                            classes=classes,
                            ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        # num_parts/part_index shard the record file across workers.
        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
            num_parts=nworker,
            part_index=rank)
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            num_parts=nworker,
            part_index=rank)
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(), normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(), normalize
        ])

        train_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)
        val_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)
        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train,
                                                      opt.rec_train_idx,
                                                      opt.rec_val,
                                                      opt.rec_val_idx,
                                                      batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt.data_dir, batch_size, num_workers)

    # Mixup produces soft labels, so accuracy is replaced by RMSE.
    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        # Blend each one-hot label with the reversed batch's labels.
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

    def smooth(label, classes, eta=0.1):
        # Standard label smoothing over one-hot targets.
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data):
        # Returns (top1-error, top5-error) over the validation set.
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1 - top1, 1 - top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        # BUGFIX: was `opt.resume_params is ''` — identity comparison with
        # a string literal; use equality.
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            # Skip weight decay for norm/bias parameters.
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        compression_params = {
            "compressor": opt.compressor,
            "ef": opt.ef,
            "momentum": opt.compress_momentum,
            "scaling": opt.onebit_scaling,
            "k": opt.k
        }

        trainer = bps.DistributedTrainer(net.collect_params(),
                                         optimizer,
                                         optimizer_params,
                                         compression_params=compression_params)
        # BUGFIX: was `opt.resume_states is not ''`; use equality.
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        # Soft labels (smoothing/mixup) require dense-label loss.
        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(
                temperature=opt.temperature,
                hard_weight=opt.hard_weight,
                sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(
                sparse_label=sparse_label_loss)

        best_val_score = 1
        # bps.byteps_declare_tensor("acc")
        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    # Teacher soft targets at the distillation temperature.
                    teacher_prob = [
                        nd.softmax(
                            teacher(X.astype(opt.dtype, copy=False)) /
                            opt.temperature) for X in data
                    ]

                with ag.record():
                    outputs = [
                        net(X.astype(opt.dtype, copy=False)) for X in data
                    ]
                    if distillation:
                        loss = [
                            L(yhat.astype('float32', copy=False),
                              y.astype('float32', copy=False),
                              p.astype('float32', copy=False))
                            for yhat, y, p in zip(outputs, label,
                                                  teacher_prob)
                        ]
                    else:
                        loss = [
                            L(yhat, y.astype(opt.dtype, copy=False))
                            for yhat, y in zip(outputs, label)
                        ]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                if opt.mixup:
                    output_softmax = [
                        nd.SoftmaxActivation(
                            out.astype('float32', copy=False))
                        for out in outputs
                    ]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f\ttime=%f'
                        % (epoch, i,
                           batch_size * nworker * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate,
                           time.time() - btic))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * nworker * i / (time.time() - tic))

            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            # acc = mx.nd.array([train_metric_score, err_top1_val, err_top5_val],
            #                   ctx=ctx[0])
            # bps.byteps_push_pull(acc, name="acc", is_average=False)
            # acc /= bps.size()
            # train_metric_score, err_top1_val, err_top5_val = acc[0].asscalar(
            # ), acc[1].asscalar(), acc[2].asscalar()
            # if bps.rank() == 0:
            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
                        (epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters(
                    '%s/%.4f-imagenet-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
                trainer.save_states(
                    '%s/%.4f-imagenet-%s-%d-best.states' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (
                    epoch + 1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params' %
                                    (save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states' %
                                    (save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params' %
                                (save_dir, model_name, opt.num_epochs - 1))
            trainer.save_states('%s/imagenet-%s-%d.states' %
                                (save_dir, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
def evaluate(model, data_iter, context):
    """Return `model`'s classification accuracy over `data_iter` on `context`."""
    metric = mx.metric.Accuracy()
    for _, batch in enumerate(data_iter):
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output])
    return metric.get()


# Load training and validation data
train_data, val_data, train_size = get_mnist_iterator()

# Initialize BytePS
bps.init()

# BytePS: pin context to local rank
context = mx.cpu(bps.local_rank()) if args.no_cuda else mx.gpu(
    bps.local_rank())
num_workers = bps.size()

# Build model
model = conv_nets()
model.cast(args.dtype)

# Initialize parameters
model.initialize(mx.init.MSRAPrelu(), ctx=context)
# if bps.rank() == 0:
# BUGFIX: summarize on the already-selected `context` instead of an
# unconditional mx.gpu(...), which crashed when --no-cuda picked mx.cpu.
model.summary(nd.ones((1, 1, 28, 28), ctx=context))
model.hybridize()
def test_onebit(self, scaling):
    # Verify that bps.DistributedTrainer with the "onebit" compressor
    # updates parameters identically to a NumPy reference that applies
    # the `onebit` compression function to the raw gradients.
    bps.init()
    ctx = mx.gpu(0)
    net = get_model("resnet18_v2")
    net.initialize(mx.init.Xavier(), ctx=ctx)
    net.summary(nd.ones((1, 3, 224, 224), ctx=ctx))

    # hyper-params: momentum/wd zeroed so the update is a pure
    # lr * compressed-gradient step, which is easy to reproduce.
    batch_size = 32
    optimizer_params = {'momentum': 0, 'wd': 0, 'learning_rate': 0.01}

    compression_params = {
        "compressor": "onebit",
        "scaling": scaling,
    }
    trainer = bps.DistributedTrainer(net.collect_params(), "sgd",
                                     optimizer_params,
                                     compression_params=compression_params)

    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    train_data = fake_data(batch_size=batch_size)

    # Snapshot the initial parameter values for the reference update.
    params = {}
    for i, param in enumerate(trainer._params):
        if param.grad_req != 'null':
            params[i] = param._data[0].asnumpy()

    for it, batch in tqdm(enumerate(train_data)):
        data = batch[0].as_in_context(ctx)
        label = batch[1].as_in_context(ctx)

        with autograd.record():
            output = net(data)
            loss = loss_fn(output, label)
        loss.backward()

        # Capture gradients (and current weights) BEFORE trainer.step()
        # consumes/compresses them — ordering here is essential.
        gs = {}
        xs = {}
        for i, param in enumerate(trainer._params):
            if param.grad_req != 'null':
                gs[i] = param._grad[0].asnumpy()
                xs[i] = param._data[0].asnumpy()

        trainer.step(batch_size)

        for i, param in enumerate(trainer._params):
            if param.grad_req != "null":
                g = gs[i] / (batch_size * bps.size())
                # NOTE(review): the gradient is compressed twice (c, then
                # cs) — presumably mirroring compress-on-push plus
                # compress-on-pull inside BytePS; confirm against the
                # server-side onebit implementation.
                c = onebit(g, scaling)
                cs = onebit(c, scaling)
                c = cs
                params[i] -= optimizer_params["learning_rate"] * c

    # Count elements that differ from the reference beyond float32 eps.
    cnt = 0
    tot = 0
    for i, param in enumerate(trainer._params):
        if param.grad_req != "null":
            x = param._data[0].asnumpy()
            tot += len(x.flatten())
            if not np.allclose(params[i], x, atol=np.finfo(
                    np.float32).eps):
                diff = np.abs(x.flatten() - params[i].flatten())
                idx = np.where(diff > np.finfo(np.float32).eps)
                cnt += len(idx[0])

    assert cnt == 0, "false/tot=%d/%d=%f" % (cnt, tot, cnt / tot)