def test_input_singleton():
    """A 1-tuple input must flow through GPipe and still receive gradients
    on both the module parameters and the input tensor itself."""

    class One(nn.Module):
        """Wraps a single Linear layer; consumes and produces a 1-tuple."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)

        def forward(self, only_a):
            (tensor,) = only_a
            return (self.fc(tensor),)

    pipeline = GPipe(nn.Sequential(One()), balance=[1], devices=['cpu'], chunks=2)

    inp = torch.rand(10, 1, requires_grad=True)
    (out,) = pipeline((inp,))
    out.mean().backward()

    # Every parameter and the original input must have been reached by autograd.
    for param in pipeline.parameters():
        assert param.grad is not None
    assert inp.grad is not None
def test_identicalness():
    """Gradients computed through GPipe must match a plain sequential run."""

    def grad_total(params):
        # Fold all available gradients into one scalar for comparison.
        total = 0
        for p in params:
            if p.grad is not None:
                total = total + p.grad.sum()
        return total

    def reset_grads(params):
        for p in params:
            p.grad = None

    batch = torch.rand(8, 1)
    model = nn.Sequential(
        nn.Linear(1, 2),
        nn.Linear(2, 4),
        nn.Linear(4, 2),
        nn.Linear(2, 1),
    )

    # Reference run: no pipeline parallelism.
    model(batch).mean().backward()
    grad_without_gpipe = grad_total(model.parameters())
    reset_grads(model.parameters())

    # Pipelined run: two CPU stages, four micro-batches.
    model = GPipe(model, [2, 2], devices=['cpu', 'cpu'], chunks=4)
    model(batch).mean().backward()
    grad_with_gpipe = grad_total(model.parameters())

    # Both grads should be identical.
    assert torch.allclose(grad_with_gpipe, grad_without_gpipe)
def main():
    """Entry point for the D-DNN ImageNet GPipe benchmark.

    Parses CLI arguments, builds the chosen model, partitions it across all
    visible GPUs with GPipe, trains for ``epochs`` epochs, and prints average
    throughput plus final validation accuracy.

    NOTE(review): relies on module-level globals not visible in this chunk:
    ``model_names``, ``batch_size``, ``datadir``, ``cores_gpu``,
    ``microbatches``, ``epochs``, ``run_epoch``, ``evaluate``.
    """
    parser = argparse.ArgumentParser(description='D-DNN imagenet benchmark')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                        ' (default: resnet50)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # Value of args.synthetic_data may seem confusing, but those values
    # come from bash and there 0=true and all else =false
    parser.add_argument('-s', '--synthetic_data', type=int, default=0,
                        help="Use synthetic data")
    args = parser.parse_args()

    # Fixed seeds so repeated benchmark runs are comparable.
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    cudnn.benchmark = True
    #---------------------------------------------------------------------------------
    # Move model to GPU.
    print("=> creating model '{}'".format(args.arch))
    # NOTE(review): model_names is indexed like a mapping here but is also used
    # as argparse choices above — presumably a dict of name -> constructed
    # model; confirm against the definition elsewhere in this file.
    model = model_names[args.arch].cuda()

    partitions = torch.cuda.device_count()
    # -1 selects a high-resolution (512x512) profiling sample; any other value
    # profiles with standard 224x224 ImageNet-sized inputs.
    if args.synthetic_data == -1:
        sample = torch.empty(batch_size, 3, 512, 512)
    else:
        sample = torch.empty(batch_size, 3, 224, 224)
    # Split the model into `partitions` stages balanced by measured layer time.
    balance = balance_by_time(partitions, model, sample)
    model = GPipe(model, balance, chunks=microbatches)
    #---------------------------------------------------------------------------------
    devices = list(model.devices)
    in_device = devices[0]    # first pipeline stage: receives inputs
    out_device = devices[-1]  # last pipeline stage: outputs / targets
    torch.cuda.set_device(in_device)

    throughputs = []
    elapsed_times = []

    #---------------------------------------------------------------------------------
    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    #---------------------------------------------------------------------------------
    # Standard ImageNet preprocessing/augmentation pipelines.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_comp = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ]
    val_comp = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ]
    if args.synthetic_data == -1:
        # Load highres data; images are pre-sized, so no crop/flip augmentation.
        traindir = datadir + '/HIGHRES/train'
        valdir = datadir + '/HIGHRES/val'
        train_comp = [transforms.ToTensor(), normalize]
        val_comp = [transforms.ToTensor(), normalize]
    elif args.synthetic_data:
        # Load normal data (non-zero == bash "false" == real dataset).
        traindir = datadir + '/train'
        valdir = datadir + '/val'
    else:
        # Load synthetic data (0 == bash "true" per the convention noted above).
        traindir = datadir + '/IMAGENET/train'
        valdir = datadir + '/IMAGENET/val'
    train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        traindir, transforms.Compose(train_comp)), batch_size=batch_size,
        shuffle=True, num_workers=cores_gpu, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir, transforms.Compose(val_comp)), batch_size=batch_size,
        shuffle=True, num_workers=cores_gpu, pin_memory=True)

    #---------------------------------------------------------------------------------
    # Training loop: collect per-epoch throughput and wall-clock time.
    for epoch in range(epochs):
        throughput, elapsed_time = run_epoch(train_loader, val_loader, model,
                                             optimizer, epoch, args, in_device,
                                             out_device)
        throughputs.append(throughput)
        elapsed_times.append(elapsed_time)
    # Single final evaluation after all epochs.
    _, valid_accuracy = evaluate(val_loader, model, args, in_device, out_device)

    # Average the per-epoch statistics; guard against epochs == 0.
    n = len(throughputs)
    throughput = sum(throughputs) / n if n > 0 else 0.0
    elapsed_time = sum(elapsed_times) / n if n > 0 else 0.0
    print('valid accuracy: %.4f | %.3f samples/sec, %.3f sec/epoch (average)'
          '' % (valid_accuracy, throughput, elapsed_time))
def test_parameters():
    """The GPipe wrapper must expose the wrapped module's parameters."""
    pipeline = GPipe(nn.Sequential(nn.Linear(1, 1)),
                     balance=[1], devices=['cpu'], chunks=1)
    assert len(list(pipeline.parameters())) > 0
# -- Pipeline setup and training (fragment) ---------------------------------
# NOTE(review): depends on names defined earlier in this file: model, args,
# batch_size, microbatches, epochs, train_loader, test_loader, run_epoch,
# evaluate.
partitions = torch.cuda.device_count()
# Profiling sample is 1x28x28 — presumably MNIST-shaped input; confirm.
sample = torch.empty(batch_size, 1, 28, 28)
# Balance the model into one stage per visible GPU by measured layer time.
balance = balance_by_time(partitions, model, sample)
model = GPipe(model, balance, chunks=microbatches)
#---------------------------------------------------------------------------------
devices = list(model.devices)
in_device = devices[0]    # first pipeline stage: receives inputs
out_device = devices[-1]  # last pipeline stage: outputs / targets
torch.cuda.set_device(in_device)
throughputs = []
elapsed_times = []
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
# Track per-epoch throughput and elapsed wall time.
for epoch in range(epochs):
    throughput, elapsed_time = run_epoch(args, model, in_device, out_device,
                                         train_loader, test_loader, epoch,
                                         optimizer)
    throughputs.append(throughput)
    elapsed_times.append(elapsed_time)
# Single evaluation pass after training completes.
_, valid_accuracy = evaluate(test_loader, in_device, out_device, model)
# Average the per-epoch statistics; guard against epochs == 0.
n = len(throughputs)
throughput = sum(throughputs) / n if n > 0 else 0.0
elapsed_time = sum(elapsed_times) / n if n > 0 else 0.0
# (fragment: tail of an nn.Sequential(...) model definition opened above)
nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    View(),  # project-local flatten layer bridging conv output to Linear head
    nn.Linear(16 * 5 * 5, 120),
    nn.ReLU(inplace=True),
    nn.Linear(120, 84),
    nn.ReLU(inplace=True),
    nn.Linear(84, 10))
# init: wrap the model as a two-stage pipeline (6 layers each), 2 micro-batches.
net = GPipe(model, balance=[6, 6], chunks=2)
print(len(net))
print('this is the end of defining model')
print("time starts")
starttime = time.time()
# Define a Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Train the network
for epoch in range(20):  # loop over the dataset multiple times
    running_loss = 0.0
    # NOTE(review): these device handles are loop-invariant and could be
    # hoisted above the epoch loop.
    device0 = torch.device("cuda:0")  # pipeline input stage
    device1 = torch.device("cuda:1")  # pipeline output stage (labels live here)
    for i, data in enumerate(trainloader, start=0):
        # print(i, data)
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        optimizer.zero_grad()
        # Inputs go to the first device; labels to the last, where the loss
        # is computed against the pipeline output.
        inputs = inputs.to(device0)
        labels = labels.to(device1)
        outputs = net(inputs)
class GPipeModel(object):
    """Encoder/decoder model pipelined over two devices with GPipe.

    Handles training with adaptive gradient clipping, stochastic weight
    averaging (SWA), periodic validation with early stopping, checkpointing,
    and top-k prediction.

    NOTE(review): relies on project helpers not visible here:
    ``gpipe_encoder``/``gpipe_decoder``, ``DenseSparseAdam``,
    ``get_p_5``/``get_n_5``, ``logger``.
    """

    def __init__(self, model_name, model_path, gradient_clip_value=5.0,
                 device_ids=None, **kwargs):
        # NOTE(review): device_ids is accepted but never used in this block.
        gpipe_model = nn.Sequential(gpipe_encoder(model_name, **kwargs),
                                    gpipe_decoder(model_name, **kwargs))
        # Two stages (encoder/decoder), two micro-batches per mini-batch.
        self.model = GPipe(gpipe_model, balance=[1, 1], chunks=2)
        self.in_device = self.model.devices[0]    # inputs enter here
        self.out_device = self.model.devices[-1]  # outputs/targets live here
        # Multi-label objective: sigmoid + BCE on raw logits.
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.model_path, self.state = model_path, {}
        os.makedirs(os.path.split(self.model_path)[0], exist_ok=True)
        # Rolling window of recent gradient norms; seeds with inf so the first
        # clip threshold is effectively unbounded.
        self.gradient_clip_value, self.gradient_norm_queue = gradient_clip_value, deque(
            [np.inf], maxlen=5)
        # Created later by get_optimizer(); train_step assumes it exists.
        self.optimizer = None

    def train_step(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Run one optimization step and return the scalar loss value."""
        self.optimizer.zero_grad()
        self.model.train()
        scores = self.model(train_x)
        loss = self.loss_fn(scores, train_y)
        loss.backward()
        self.clip_gradient()
        self.optimizer.step(closure=None)
        return loss.item()

    def predict_step(self, data_x: torch.Tensor, k: int):
        """Return (top-k probabilities, top-k label indices) on CPU."""
        self.model.eval()
        with torch.no_grad():
            scores, labels = torch.topk(self.model(data_x), k)
            return torch.sigmoid(scores).cpu(), labels.cpu()

    def get_optimizer(self, **kwargs):
        """Instantiate the optimizer over the pipelined model's parameters."""
        self.optimizer = DenseSparseAdam(self.model.parameters(), **kwargs)

    def train(self, train_loader: DataLoader, valid_loader: DataLoader,
              opt_params: Optional[Mapping] = None, nb_epoch=100, step=100,
              k=5, early=100, verbose=True, swa_warmup=None, **kwargs):
        """Train with periodic validation every ``step`` global steps.

        Saves the best checkpoint by N@5, stops early after ``early``
        validations without improvement, and maintains SWA weights once
        epoch ``swa_warmup`` is reached.
        """
        self.get_optimizer(**({} if opt_params is None else opt_params))
        global_step, best_n5, e = 0, 0.0, 0
        print_loss = 0.0
        #
        for epoch_idx in range(nb_epoch):
            # Begin SWA accumulation exactly at the warmup epoch (no-op when
            # swa_warmup is None).
            if epoch_idx == swa_warmup:
                self.swa_init()
            for i, (train_x, train_y) in enumerate(train_loader, 1):
                global_step += 1
                loss = self.train_step(
                    train_x.to(self.in_device, non_blocking=True),
                    train_y.to(self.out_device, non_blocking=True))
                print_loss += loss
                # Periodic validation under the SWA-averaged weights.
                if global_step % step == 0:
                    self.swa_step()
                    # Swap averaged weights in for evaluation...
                    self.swap_swa_params()
                    ##
                    labels = []
                    valid_loss = 0.0
                    self.model.eval()
                    with torch.no_grad():
                        for (valid_x, valid_y) in valid_loader:
                            logits = self.model(
                                valid_x.to(self.in_device, non_blocking=True))
                            valid_loss += self.loss_fn(
                                logits,
                                valid_y.to(self.out_device,
                                           non_blocking=True)).item()
                            scores, tmp = torch.topk(logits, k)
                            labels.append(tmp.cpu())
                    valid_loss /= len(valid_loader)
                    labels = np.concatenate(labels)
                    ##
                    # labels = np.concatenate([self.predict_step(valid_x, k)[1] for valid_x in valid_loader])
                    targets = valid_loader.dataset.data_y
                    p5, n5 = get_p_5(labels, targets), get_n_5(labels, targets)
                    if n5 > best_n5:
                        # Only persist once well past warmup (3x swa_warmup).
                        # NOTE(review): raises TypeError if swa_warmup is None
                        # and a new best is found — confirm callers always set it.
                        self.save_model(epoch_idx > 3 * swa_warmup)
                        best_n5, e = n5, 0
                    else:
                        e += 1
                        if early is not None and e > early:
                            return
                    # ...and swap the live training weights back in.
                    self.swap_swa_params()
                    if verbose:
                        log_msg = '%d %d train loss: %.7f valid loss: %.7f P@5: %.5f N@5: %.5f early stop: %d' % \
                            (epoch_idx, i * train_loader.batch_size,
                             print_loss / step, valid_loss, round(p5, 5),
                             round(n5, 5), e)
                        logger.info(log_msg)
                    print_loss = 0.0

    def predict(self, data_loader: DataLoader, k=100, desc='Predict', **kwargs):
        """Load the best checkpoint and predict top-``k`` for every batch."""
        self.load_model()
        scores_list, labels_list = zip(*(
            self.predict_step(data_x.to(self.in_device, non_blocking=True), k)
            for data_x in tqdm(data_loader, desc=desc, leave=False)))
        return np.concatenate(scores_list), np.concatenate(labels_list)

    def save_model(self, last_epoch):
        """Persist the state dict, retrying up to 5 times on failure."""
        if not last_epoch:
            return
        for trial in range(5):
            try:
                torch.save(self.model.state_dict(), self.model_path)
                break
            # NOTE(review): bare except swallows everything, including
            # KeyboardInterrupt; narrow to Exception (or OSError) and log.
            except:
                print('saving failed')

    def load_model(self):
        """Restore model weights from ``self.model_path``."""
        self.model.load_state_dict(torch.load(self.model_path))

    def clip_gradient(self):
        """Clip gradients to a multiple of the recent max gradient norm.

        The clip threshold adapts: it is ``gradient_clip_value`` times the
        largest norm seen in the last 5 steps, and the recorded norm is
        capped so one spike cannot permanently raise the threshold.
        """
        if self.gradient_clip_value is not None:
            max_norm = max(self.gradient_norm_queue)
            total_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm * self.gradient_clip_value)
            self.gradient_norm_queue.append(
                min(total_norm, max_norm * 2.0, 1.0))
            if total_norm > max_norm * self.gradient_clip_value:
                # NOTE(review): logger.warn is deprecated (use warning), and
                # round(total_norm, 5) assumes a float — recent torch returns
                # a 0-d tensor from clip_grad_norm_; confirm torch version.
                logger.warn(
                    F'Clipping gradients with total norm {round(total_norm, 5)} '
                    F'and max norm {round(max_norm, 5)}')

    def swa_init(self):
        """Initialize SWA state with a CPU copy of the current parameters."""
        if 'swa' not in self.state:
            logger.info('SWA Initializing')
            swa_state = self.state['swa'] = \
                {'models_num': 1}
            for n, p in self.model.named_parameters():
                swa_state[n] = p.data.cpu().detach()

    def swa_step(self):
        """Fold the current parameters into the SWA running average."""
        if 'swa' in self.state:
            swa_state = self.state['swa']
            swa_state['models_num'] += 1
            beta = 1.0 / swa_state['models_num']
            with torch.no_grad():
                for n, p in self.model.named_parameters():
                    # avg = (1 - beta) * avg + beta * current (on CPU).
                    # NOTE(review): add_(beta, tensor) is the deprecated
                    # argument order; modern torch wants add_(t, alpha=beta).
                    swa_state[n].mul_(1.0 - beta).add_(beta, p.data.cpu())

    def swap_swa_params(self):
        """Exchange live parameters with the stored SWA averages in place.

        Called in pairs: once before validation (evaluate averaged weights)
        and once after (restore training weights).
        """
        if 'swa' in self.state:
            swa_state = self.state['swa']
            for n, p in self.model.named_parameters():
                # NOTE(review): gpu_id is captured but unused because the
                # restore-to-GPU line below is commented out — after a swap,
                # p.data holds the CPU-side tensor; confirm this is intended.
                gpu_id = p.get_device()
                p.data, swa_state[n] = swa_state[n], p.data.cpu()
                # p.data = p.data.cuda(gpu_id)

    def disable_swa(self):
        """Drop all SWA state, reverting to plain training."""
        if 'swa' in self.state:
            del self.state['swa']
# (fragment: tail of the balance-mode if/elif dispatch opened above)
    balance = balance_by_size(partitions, model, sample)
else:
    # Any unrecognized value of the balance_by argument is rejected outright.
    raise NotImplementedError("Unsupport value specified for 'balance_by' argument")
print("== Wrapping model as GPipe model ==")
model = GPipe(model, balance, chunks=args.num_microbatches)
#---------------------------------------------------------------------------------
# Specify input and output to the correct device
devices = list(model.devices)
in_device = devices[0]    # first pipeline stage: receives inputs
out_device = devices[-1]  # last pipeline stage: outputs / targets
throughputs = []
elapsed_times = []
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
print("== Start training ==")
# Track per-epoch throughput and elapsed wall time.
for epoch in range(epochs):
    throughput, elapsed_time = run_epoch(args, model, in_device, out_device,
                                         train_loader, test_loader, epoch,
                                         optimizer)
    throughputs.append(throughput)
    elapsed_times.append(elapsed_time)
# Single evaluation pass after training completes.
_, valid_accuracy = evaluate(test_loader, in_device, out_device, model)
# Average the per-epoch statistics; guard against epochs == 0.
n = len(throughputs)
throughput = sum(throughputs) / n if n > 0 else 0.0
elapsed_time = sum(elapsed_times) / n if n > 0 else 0.0
print('valid accuracy: %.4f | %.3f samples/sec, %.3f sec/epoch (average)'
      '' % (valid_accuracy, throughput, elapsed_time))