def __init__(self, destination_path, topology):
    """Initialize master-telemetry aggregation state.

    Args:
        destination_path: path where the master telemetry statistics are
            periodically saved.
        topology: iterable of ``(u, v)`` edges; the distinct endpoints form
            the set of participating agents.
    """
    self.destination_path = destination_path
    # Distinct agent tokens appearing at either end of any topology edge.
    self.agents = list(
        {endpoint for edge in topology for endpoint in (edge[0], edge[1])})
    self.stats = ModelStatistics('MASTER TELEMETRY',
                                 save_path=destination_path)
    # batch_number -> {agent -> parameter vector}; an entry is consumed once
    # every agent has reported for that batch.
    self.agent_params_by_iter = dict()
    # agent -> general-info payload (e.g. batches per epoch).
    self.agent_general_info = dict()
def __init__(self, destination_path, topology, resume=False):
    """Initialize master-telemetry aggregation state, optionally resuming.

    Args:
        destination_path: path where the master telemetry statistics are
            saved (and loaded from when resuming).
        topology: iterable of ``(u, v)`` edges; the distinct endpoints form
            the set of participating agents.
        resume: when True, reload previously dumped statistics from
            ``destination_path`` instead of starting fresh.
    """
    self.destination_path = destination_path
    # Distinct agent tokens appearing at either end of any topology edge.
    self.agents = list(
        {endpoint for edge in topology for endpoint in (edge[0], edge[1])})
    if resume:
        self.stats = ModelStatistics.load_from_file(destination_path)
    else:
        self.stats = ModelStatistics('MASTER TELEMETRY',
                                     save_path=destination_path)
    # batch_number -> {agent -> parameter vector}; an entry is consumed once
    # every agent has reported for that batch.
    self.agent_params_by_iter = dict()
    # agent -> general-info payload (e.g. batches per epoch).
    self.agent_general_info = dict()
def main():
    """Train a single (non-distributed) ResNet on CIFAR-10.

    Parses CLI args, optionally resumes from a checkpoint, then runs the
    train/validate loop, checkpointing and dumping statistics each epoch.
    Relies on module-level ``parser``, ``best_prec1`` and helper functions
    (``train``, ``validate``, ``save_checkpoint``).
    """
    global args, best_prec1
    args = parser.parse_args()
    torch.manual_seed(239)

    # Make sure the save directory exists.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    statistics = ModelStatistics('Single model')

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(args.resume,
                                             'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(args.resume, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            # BUG FIX: the message previously formatted args.evaluate; the
            # checkpoint was loaded from args.resume.
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data',
                         train=True,
                         transform=transforms.Compose([
                             transforms.RandomHorizontalFlip(),
                             transforms.RandomCrop(32, 4),
                             transforms.ToTensor(),
                             normalize,
                         ]),
                         download=True),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data',
                         train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=128,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # Define loss function (criterion) and optimizer.
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    def lr_schedule(epoch):
        # Step decay: x0.1 at epoch 81 and again at epoch 122.
        factor = 1
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if args.arch != 'resnet20':
        print(
            'This code was not intended to be used on resnets other than resnet20'
        )

    if args.arch in ['resnet1202', 'resnet110']:
        # For resnet1202 the original paper uses lr=0.01 for the first 400
        # minibatches as warm-up, then switches back. In this setup it
        # corresponds to the first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        statistics.set_epoch(epoch)

        # Train for one epoch.
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        train(train_loader, model, criterion, optimizer, epoch, statistics)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # Evaluate on the validation set.
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # Remember best prec@1 and save checkpoint.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'statistics': pickle.dumps(statistics)
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.th'))

        statistics.dump_to_file(
            os.path.join(args.save_dir, 'statistics.pickle'))
async def main():
    """Train one consensus agent's ResNet on its CIFAR-10 shard.

    Starts the agent's serving task, partitions the dataset across agents
    (either contiguously or by target class when ``--target-split`` is set),
    performs an initial parameter-averaging round, then runs the
    train/validate loop with periodic averaging via ``run_averaging``.
    Relies on module-level ``parser``, ``best_prec1`` and helper functions.
    """
    global args, best_prec1
    args = parser.parse_args()
    torch.manual_seed(239)
    print('Consensus agent: {}'.format(args.agent_token))

    convergence_eps = 1e-4
    agent = ConsensusAgent(args.agent_token,
                           args.agent_host,
                           args.agent_port,
                           args.master_host,
                           args.master_port,
                           convergence_eps=convergence_eps,
                           debug=True if args.debug else False)
    # Keep a reference so the serving task is not garbage-collected.
    agent_serve_task = asyncio.create_task(agent.serve_forever())
    print('{}: Created serving task'.format(args.agent_token))

    # Each agent gets its own save directory.
    args.save_dir = os.path.join(args.save_dir, str(args.agent_token))
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    statistics = ModelStatistics(args.agent_token)

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            if args.logging:
                print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(args.resume,
                                             'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(args.resume, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            if args.logging:
                # BUG FIX: the message previously formatted args.evaluate;
                # the checkpoint was loaded from args.resume.
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if args.logging:
                print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    dataset_path = os.path.join('./data/', str(args.agent_token))
    train_dataset = datasets.CIFAR10(root=dataset_path,
                                     train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomCrop(32, 4),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]),
                                     download=True)

    # Contiguous per-agent shard of the training set.
    size_per_agent = len(train_dataset) // args.total_agents
    train_indices = list(
        range(args.agent_token * size_per_agent,
              min(len(train_dataset),
                  (args.agent_token + 1) * size_per_agent)))
    if args.target_split:
        # BUG FIX: CIFAR10.targets is a plain Python list, so the previous
        # expression `list(range(len(ds)))[ds.targets == token]` compared a
        # list to an int (always False) and indexed element 0, producing a
        # single int instead of the indices of this agent's target class.
        train_indices = [
            i for i, target in enumerate(train_dataset.targets)
            if target == args.agent_token
        ]
        print('Target split: {} samples for agent {}'.format(
            len(train_indices), args.agent_token))

    from torch.utils.data.sampler import SubsetRandomSampler
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # ordering is delegated to the sampler
        num_workers=args.workers,
        pin_memory=True,
        sampler=SubsetRandomSampler(train_indices))

    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=dataset_path,
                         train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=128,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # Define loss function (criterion) and optimizer.
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    def lr_schedule(epoch):
        # Base factor scales lr by the number of agents; step decay x0.1 at
        # epochs 81 and 122.
        factor = args.total_agents
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if args.arch != 'resnet20':
        print(
            'This code was not intended to be used on resnets other than resnet20'
        )

    if args.arch in ['resnet1202', 'resnet110']:
        # For resnet1202 the original paper uses lr=0.01 for the first 400
        # minibatches as warm-up, then switches back. In this setup it
        # corresponds to the first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    def dump_params(model):
        # Flatten the full state dict into one float32 numpy vector.
        return torch.cat([
            v.to(torch.float32).view(-1)
            for k, v in model.state_dict().items()
        ]).cpu().numpy()

    def load_params(model, params):
        # Inverse of dump_params: slice the flat vector back into tensors.
        st = model.state_dict()
        used_params = 0
        for k in st.keys():
            cnt_params = st[k].numel()
            st[k] = torch.Tensor(
                params[used_params:used_params + cnt_params]).view(
                    st[k].shape).to(st[k].dtype).to(st[k].device)
            used_params += cnt_params
        model.load_state_dict(st)

    async def run_averaging():
        # One consensus averaging round over the current model parameters.
        params = dump_params(model)
        params = await agent.run_once(params)
        load_params(model, params)

    if args.logging:
        print('Starting initial averaging...')
    params = dump_params(model)
    params = await agent.run_round(params, 1.0 if args.init_leader else 0.0)
    load_params(model, params)
    if args.logging:
        print('Initial averaging completed!')

    for epoch in range(args.start_epoch, args.epochs):
        statistics.set_epoch(epoch)

        # Train for one epoch.
        if args.logging:
            print('current lr {:.5e}'.format(
                optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        await train(train_loader, model, criterion, optimizer, epoch,
                    statistics, run_averaging)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # Evaluate on the validation set.
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # Remember best prec@1 and save checkpoint.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'statistics': pickle.dumps(statistics)
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.th'))

        statistics.dump_to_file(
            os.path.join(args.save_dir, 'statistics.pickle'))
def plot_loop(names, paths, title, save=None, param_dev=None):
    """Continuously re-plot training telemetry until interrupted.

    Every 5 seconds, reloads ModelStatistics from ``paths`` and redraws
    loss / validation-accuracy panels (plus a parameter-deviation panel when
    ``param_dev`` is given), displaying the figure in a notebook and
    optionally saving it to ``save``. SIGINT is routed to ``Finish`` to end
    the loop cleanly.

    Args:
        names: labels for each statistics file, zipped with ``paths``.
        paths: paths to dumped ModelStatistics files.
        title: figure suptitle.
        save: optional path to save each rendered figure.
        param_dev: optional path to a master-telemetry statistics file with
            parameter-deviation records.
    """
    finish = Finish()
    signal.signal(signal.SIGINT, finish)
    plt.rcParams['font.size'] = 18
    plt.rcParams['axes.facecolor'] = 'white'
    plt.rcParams['figure.facecolor'] = 'white'

    while not finish.finished():
        stats = {
            name: ModelStatistics.load_from_file(path)
            for name, path in zip(names, paths)
        }
        param_dev_stats = ModelStatistics.load_from_file(
            param_dev) if param_dev else None

        if param_dev_stats:
            fig, (plt_loss, plt_param_dev, plt_val_acc) = plt.subplots(3, 1)
            fig.set_size_inches(18, 20)
        else:
            plt_param_dev = None
            fig, (plt_loss, plt_val_acc) = plt.subplots(2, 1)
            fig.set_size_inches(18, 20 * 0.7)
        fig.suptitle(title, fontsize=20)

        plt_loss.set_ylabel('Loss (local)')
        plt_loss.set_yscale('log')
        plt_loss.set_xlabel('Epoch')
        if param_dev_stats:
            plt_param_dev.set_ylabel(
                'Parameter deviation (coef. of variation)')
            plt_param_dev.set_xlabel('Epoch')
            plt_param_dev.set_yscale('log')
        plt_val_acc.set_ylabel('Validation Accuracy, %')
        plt_val_acc.set_xlabel('Epoch')

        for label, stat in stats.items():
            loss = stat.crop('train_loss')
            val_acc = stat.crop('val_precision')
            # Consensus runs are drawn dashed and thinner to distinguish
            # them from the baseline curves.
            fmt = {}
            if label.lower().find('consensus') != -1:
                fmt['linestyle'] = 'dashed'
                fmt['linewidth'] = 1.1
            else:
                fmt['linestyle'] = None
                fmt['linewidth'] = 1.5
            plt_loss.plot(range(len(loss)), loss, label=label, **fmt)
            plt_val_acc.plot(range(len(val_acc)),
                             val_acc,
                             label=label + ' ({})'.format(val_acc[-1]),
                             **fmt)

        if param_dev_stats:
            # Deviation data may be partially written while training runs;
            # plotting it is best-effort.
            # BUG FIX: the original bare `except:` clauses also swallowed
            # KeyboardInterrupt/SystemExit; narrowed to Exception.
            try:
                telemetries_per_epoch = next(
                    iter(
                        param_dev_stats.crop('telemetries_per_epoch')
                        [0].values()))
                try:
                    deviation = param_dev_stats.crop('coef_of_var')
                    plt_param_dev.plot([
                        b / telemetries_per_epoch
                        for b in range(len(deviation))
                    ],
                                       deviation,
                                       label='max')
                except Exception:
                    pass
                try:
                    cv_pctls = param_dev_stats.crop(
                        'abs_coef_of_var_percentiles')
                except Exception:
                    cv_pctls = param_dev_stats.crop(
                        'coef_of_var_percentiles')
                grouped_by_pcts = dict()
                for record in cv_pctls:
                    for (pct, val) in record:
                        if pct not in grouped_by_pcts.keys():
                            grouped_by_pcts[pct] = []
                        grouped_by_pcts[pct].append(val)
                for pct, vals in reversed(list(grouped_by_pcts.items())):
                    # Only percentiles in [75, 99] are informative enough
                    # to plot.
                    if pct < 75 or 99 < pct:
                        continue
                    plt_param_dev.plot(
                        [
                            b / telemetries_per_epoch
                            for b in range(len(vals))
                        ],
                        vals,
                        label='percentile={}'.format(pct))
            except Exception:
                pass

        plt_loss.legend()
        plt_val_acc.legend()
        if param_dev_stats:
            plt_param_dev.legend()
        fig.tight_layout()
        plt.close(fig)
        clear_output(wait=True)
        display(fig)
        if save is not None:
            fig.savefig(save)
        time.sleep(5.0)
async def main(cfg):
    """Train one consensus agent's model, driven entirely by `cfg`.

    Variant of the agent training entry point that takes a config object
    instead of global argparse state, delegates consensus setup to
    ConsensusSpecific, and loads data via get_agent_train_loader /
    get_agent_val_loader (helpers defined elsewhere in the project).
    """
    best_prec1 = 0
    torch.manual_seed(239)
    print('Consensus agent: {}'.format(cfg.agent_token))

    consensus_specific = ConsensusSpecific(cfg)
    consensus_specific.init_consensus()

    # Check the save_dir exists or not; each agent gets its own subdirectory.
    cfg.save_dir = os.path.join(cfg.save_dir, str(cfg.agent_token))
    if not os.path.exists(cfg.save_dir):
        os.makedirs(cfg.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[cfg.arch]())
    model.cuda()
    print('{}: Created model'.format(cfg.agent_token))

    statistics = ModelStatistics(cfg.agent_token)

    # optionally resume from a checkpoint (always from this agent's own
    # save_dir, unlike the args-based variant which takes an explicit path)
    if cfg.do_resume:
        checkpoint_path = os.path.join(cfg.save_dir, 'checkpoint.th')
        if os.path.isfile(checkpoint_path):
            if cfg.logging:
                print("=> loading checkpoint '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path)
            cfg.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(cfg.save_dir,
                                             'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(cfg.save_dir, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            if cfg.logging:
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    checkpoint_path, checkpoint['epoch']))
        else:
            if cfg.logging:
                print("=> no checkpoint found at '{}'".format(
                    checkpoint_path))

    cudnn.benchmark = True

    print('{}: Loading dataset...'.format(cfg.agent_token))
    train_loader = get_agent_train_loader(cfg.agent_token, cfg.batch_size)
    print('{}: loaded {} batches for train'.format(cfg.agent_token,
                                                   len(train_loader)))
    # Validation may be disabled entirely via cfg.no_validation.
    val_loader = None if cfg.no_validation else get_agent_val_loader(
        cfg.agent_token)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if cfg.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                cfg.lr,
                                momentum=cfg.momentum,
                                weight_decay=cfg.weight_decay)

    def lr_schedule(epoch):
        # LSR ("linear scaling rule", presumably — TODO confirm): ramp the
        # factor geometrically up to total_agents over the warmup epochs,
        # then hold it; step decay x0.1 at epochs 81 and 122 regardless.
        if cfg.use_lsr and epoch < cfg.warmup:
            factor = np.power(cfg.total_agents, epoch / cfg.warmup)
        else:
            factor = cfg.total_agents if cfg.use_lsr else 1.0
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if cfg.arch != 'resnet20':
        print(
            'This code was not intended to be used on resnets other than resnet20'
        )

    if cfg.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
        # then switch back. In this setup it will correspond for first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = cfg.lr * 0.1

    if cfg.evaluate:
        validate(cfg, val_loader, model, criterion)
        return

    # Announce this agent's telemetry cadence to the master.
    await consensus_specific.agent.send_telemetry(
        TelemetryAgentGeneralInfo(
            telemetries_per_epoch=cfg.telemetry_freq_per_epoch))

    # Replay the scheduler up to start_epoch so a resumed run continues with
    # the correct learning rate.
    for epoch in range(0, cfg.start_epoch):
        lr_scheduler.step()

    for epoch in range(cfg.start_epoch, cfg.epochs):
        # train for one epoch
        if cfg.logging:
            print('current lr {:.5e}'.format(
                optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        await train(consensus_specific, train_loader, model, criterion,
                    optimizer, epoch, statistics)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # evaluate on validation set
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(cfg, val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % cfg.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'statistics': pickle.dumps(statistics)
                },
                is_best,
                filename=os.path.join(cfg.save_dir, 'checkpoint.th'))

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(cfg.save_dir, 'model.th'))

        statistics.dump_to_file(
            os.path.join(cfg.save_dir, 'statistics.pickle'))

    consensus_specific.stop_consensus()
class ResNet20TelemetryProcessor(TelemetryProcessor):
    """Aggregates per-agent model-parameter telemetry on the master.

    Collects flattened parameter vectors per batch number; once every agent
    has reported for a batch, records deviation norms and the coefficient of
    variation, dumps statistics, and frees that batch's buffer.
    """

    def __init__(self, destination_path, topology):
        """Args:
            destination_path: path where master telemetry statistics are saved.
            topology: iterable of ``(u, v)`` edges; distinct endpoints form
                the agent set.
        """
        self.destination_path = destination_path
        # Distinct agent tokens appearing at either end of any topology edge.
        self.agents = list(
            {endpoint for edge in topology for endpoint in (edge[0], edge[1])})
        self.stats = ModelStatistics('MASTER TELEMETRY',
                                     save_path=destination_path)
        # batch_number -> {agent -> parameter vector}; drained once complete.
        self.agent_params_by_iter = dict()
        # agent -> general-info payload.
        self.agent_general_info = dict()

    def process(self, token, payload):
        """Handle one telemetry payload from agent ``token``.

        Raises:
            ValueError: for payload types this processor does not support.
        """
        if isinstance(payload, TelemetryModelParameters):
            # Buffer this agent's parameters under its batch number.
            batch_params = self.agent_params_by_iter.setdefault(
                payload.batch_number, dict())
            batch_params[token] = payload.parameters
            if len(batch_params) == len(self.agents):
                # All agents reported for this batch: record deviation stats.
                params = batch_params
                avg_params = np.mean([params[agent] for agent in self.agents],
                                     axis=0)
                deviation_params = {
                    agent: params[agent] - avg_params
                    for agent in self.agents
                }
                self.stats.add(
                    'param_deviation_L1', {
                        agent: np.linalg.norm(deviation_params[agent], ord=1)
                        for agent in self.agents
                    })
                self.stats.add(
                    'param_deviation_L2', {
                        agent: np.linalg.norm(deviation_params[agent], ord=2)
                        for agent in self.agents
                    })
                self.stats.add(
                    'param_deviation_Linf', {
                        agent: np.linalg.norm(deviation_params[agent],
                                              ord=np.inf)
                        for agent in self.agents
                    })
                # Worst per-coordinate coefficient of variation across agents.
                # NOTE(review): coordinates with zero mean yield inf/nan here
                # — behavior inherited from the original computation.
                arr_params = np.array(
                    [params[agent] for agent in self.agents])
                max_cv = np.linalg.norm(np.std(arr_params, axis=0) /
                                        np.mean(arr_params, axis=0),
                                        ord=np.inf)
                self.stats.add('coef_of_var', max_cv)
                self.stats.dump_to_file()
                # Free the completed batch buffer.
                del self.agent_params_by_iter[payload.batch_number]
        elif isinstance(payload, TelemetryAgentGeneralInfo):
            self.agent_general_info[token] = payload
            if len(self.agent_general_info) == len(self.agents):
                self.stats.add(
                    'batches_per_epoch', {
                        agent:
                        self.agent_general_info[agent].batches_per_epoch
                        for agent in self.agents
                    })
                self.stats.dump_to_file()
        else:
            raise ValueError(
                f'Got unsupported payload from {token}: {payload!r}')