def main():
    global args, best_prec1
    args = parser.parse_args()
    torch.manual_seed(239)

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    statistics = ModelStatistics('Single model')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(args.resume, 'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(args.resume, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=True,
                         transform=transforms.Compose([
                             transforms.RandomHorizontalFlip(),
                             transforms.RandomCrop(32, 4),
                             transforms.ToTensor(),
                             normalize,
                         ]),
                         download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=128, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    def lr_schedule(epoch):
        # multiplicative factor applied to the base lr by LambdaLR
        factor = 1
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if args.arch != 'resnet20':
        print('This code was not intended to be used on resnets other than resnet20')

    if args.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 the original paper uses lr=0.01 for the first 400 minibatches
        # as warm-up, then switches back. In this setup that corresponds to the first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        statistics.set_epoch(epoch)

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        train(train_loader, model, criterion, optimizer, epoch, statistics)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # evaluate on validation set
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'statistics': pickle.dumps(statistics),
            }, is_best, filename=os.path.join(args.save_dir, 'checkpoint.th'))

        save_checkpoint({
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        statistics.dump_to_file(os.path.join(args.save_dir, 'statistics.pickle'))
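# Hedged example (not part of the original script): how a checkpoint written by
# save_checkpoint() above could be reloaded for offline inspection. The keys mirror
# the dicts saved in the training loop; the path and architecture are illustrative.
def load_single_model_checkpoint(path='save_temp/checkpoint.th', arch='resnet20'):
    checkpoint = torch.load(path, map_location='cpu')
    model = torch.nn.DataParallel(resnet.__dict__[arch]())
    model.load_state_dict(checkpoint['state_dict'])
    # 'statistics' is only present in the periodic checkpoint.th, not in model.th
    stats = pickle.loads(checkpoint['statistics']) if 'statistics' in checkpoint else None
    return model, checkpoint.get('best_prec1'), stats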
async def main():
    global args, best_prec1
    args = parser.parse_args()
    torch.manual_seed(239)

    print('Consensus agent: {}'.format(args.agent_token))
    convergence_eps = 1e-4
    agent = ConsensusAgent(args.agent_token, args.agent_host, args.agent_port,
                           args.master_host, args.master_port,
                           convergence_eps=convergence_eps,
                           debug=True if args.debug else False)
    agent_serve_task = asyncio.create_task(agent.serve_forever())
    print('{}: Created serving task'.format(args.agent_token))

    # Check the save_dir exists or not
    args.save_dir = os.path.join(args.save_dir, str(args.agent_token))
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    statistics = ModelStatistics(args.agent_token)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            if args.logging:
                print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(args.resume, 'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(args.resume, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            if args.logging:
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if args.logging:
                print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    dataset_path = os.path.join('./data/', str(args.agent_token))
    train_dataset = datasets.CIFAR10(root=dataset_path, train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomCrop(32, 4),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]),
                                     download=True)

    # by default each agent trains on a contiguous shard of the training set
    size_per_agent = len(train_dataset) // args.total_agents
    train_indices = list(
        range(args.agent_token * size_per_agent,
              min(len(train_dataset), (args.agent_token + 1) * size_per_agent)))
    if args.target_split:
        # CIFAR10.targets is a plain Python list, so select the indices whose
        # label equals this agent's token
        train_indices = [i for i, target in enumerate(train_dataset.targets)
                         if target == args.agent_token]
        print('Target split: {} samples for agent {}'.format(
            len(train_indices), args.agent_token))

    from torch.utils.data.sampler import SubsetRandomSampler
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # must stay False: SubsetRandomSampler does the shuffling
        num_workers=args.workers,
        pin_memory=True,
        sampler=SubsetRandomSampler(train_indices))

    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=dataset_path, train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=128, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    def lr_schedule(epoch):
        # base lr is scaled by the number of agents, then decayed at the usual epochs
        factor = args.total_agents
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if args.arch != 'resnet20':
        print('This code was not intended to be used on resnets other than resnet20')

    if args.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 the original paper uses lr=0.01 for the first 400 minibatches
        # as warm-up, then switches back. In this setup that corresponds to the first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    def dump_params(model):
        # flatten the full state_dict (parameters and buffers) into one float32 vector
        return torch.cat([
            v.to(torch.float32).view(-1) for k, v in model.state_dict().items()
        ]).cpu().numpy()

    def load_params(model, params):
        # inverse of dump_params: slice the flat vector back into state_dict tensors
        st = model.state_dict()
        used_params = 0
        for k in st.keys():
            cnt_params = st[k].numel()
            st[k] = torch.Tensor(params[used_params:used_params + cnt_params]) \
                .view(st[k].shape).to(st[k].dtype).to(st[k].device)
            used_params += cnt_params
        model.load_state_dict(st)

    async def run_averaging():
        # one consensus step: exchange the flattened parameters with the other agents
        params = dump_params(model)
        params = await agent.run_once(params)
        load_params(model, params)

    if args.logging:
        print('Starting initial averaging...')
    params = dump_params(model)
    params = await agent.run_round(params, 1.0 if args.init_leader else 0.0)
    load_params(model, params)
    if args.logging:
        print('Initial averaging completed!')

    for epoch in range(args.start_epoch, args.epochs):
        statistics.set_epoch(epoch)

        # train for one epoch
        if args.logging:
            print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        await train(train_loader, model, criterion, optimizer, epoch,
                    statistics, run_averaging)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # evaluate on validation set
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'statistics': pickle.dumps(statistics),
            }, is_best, filename=os.path.join(args.save_dir, 'checkpoint.th'))

        save_checkpoint({
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        statistics.dump_to_file(os.path.join(args.save_dir, 'statistics.pickle'))
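# Hedged sketch (an assumption; the real launcher is outside this excerpt): the
# consensus variant of main() is a coroutine, so each agent process would drive it
# with an event loop, e.g. asyncio.run(). Arguments such as the agent token and
# master host/port are parsed inside main() by the module-level parser.
if __name__ == '__main__':
    asyncio.run(main())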