예제 #1
0
    def __init__(self, destination_path, topology):
        """Set up master-side telemetry storage.

        Collects the set of agent tokens appearing as either endpoint of
        the (u, v) edges in `topology` and prepares a ModelStatistics
        sink that persists to `destination_path`.
        """
        self.destination_path = destination_path
        # Every endpoint mentioned in the edge list is an agent.
        self.agents = list({uv[0] for uv in topology}
                           | {uv[1] for uv in topology})

        self.stats = ModelStatistics('MASTER TELEMETRY',
                                     save_path=destination_path)
        # batch_number -> {agent token -> reported parameters}
        self.agent_params_by_iter = {}
        # agent token -> general-info payload
        self.agent_general_info = {}
    def __init__(self, destination_path, topology, resume=False):
        """Set up master-side telemetry storage.

        Collects the set of agent tokens appearing as either endpoint of
        the (u, v) edges in `topology`. With `resume=True` the previously
        dumped ModelStatistics file at `destination_path` is reloaded;
        otherwise a fresh one is created there.
        """
        self.destination_path = destination_path
        # Every endpoint mentioned in the edge list is an agent.
        self.agents = list({uv[0] for uv in topology}
                           | {uv[1] for uv in topology})

        if resume:
            self.stats = ModelStatistics.load_from_file(destination_path)
        else:
            self.stats = ModelStatistics('MASTER TELEMETRY',
                                         save_path=destination_path)

        # batch_number -> {agent token -> reported parameters}
        self.agent_params_by_iter = {}
        # agent token -> general-info payload
        self.agent_general_info = {}
예제 #3
0
def main():
    """Train a single-node ResNet on CIFAR-10.

    Parses command-line arguments, optionally resumes weights / best
    precision / telemetry from a checkpoint, then runs the standard
    train-validate loop, checkpointing the model and `ModelStatistics`
    telemetry every epoch.

    Relies on module-level `args`, `best_prec1`, `parser`, `resnet`,
    `train`, `validate`, `save_checkpoint` and `ModelStatistics`.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Fixed seed for run-to-run reproducibility.
    torch.manual_seed(239)

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    statistics = ModelStatistics('Single model')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # Telemetry may live inside the checkpoint (pickled) or in a
            # separate statistics.pickle file next to it.
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(args.resume,
                                             'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(args.resume, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            # Fix: report the checkpoint that was actually loaded
            # (args.resume) — the message previously printed
            # args.evaluate.
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Standard ImageNet-style channel statistics for normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        root='./data',
        train=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]),
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        root='./data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=128,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    def lr_schedule(epoch):
        # Step-decay schedule from the ResNet/CIFAR recipe: divide the
        # learning rate by 10 at epochs 81 and 122.
        factor = 1
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if args.arch != 'resnet20':
        print(
            'This code was not intended to be used on resnets other than resnet20'
        )

    if args.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
        # then switch back. In this setup it will correspond for first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # Fix: fast-forward the LR schedule when resuming so the learning
    # rate matches an uninterrupted run (mirrors the consensus training
    # script in this project; previously a resumed run restarted at the
    # epoch-0 learning rate).
    for epoch in range(0, args.start_epoch):
        lr_scheduler.step()

    for epoch in range(args.start_epoch, args.epochs):
        statistics.set_epoch(epoch)
        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        train(train_loader, model, criterion, optimizer, epoch, statistics)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # evaluate on validation set
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # Full checkpoint (weights + epoch + pickled telemetry) every
        # `save_every` epochs.
        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'statistics': pickle.dumps(statistics)
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        # Lightweight model-only checkpoint every epoch.
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.th'))
        statistics.dump_to_file(
            os.path.join(args.save_dir, 'statistics.pickle'))
async def main():
    """Entry point for one consensus-averaging training agent.

    Spawns the agent's gossip server, trains a ResNet on this agent's
    slice of CIFAR-10 while averaging parameters with peers through
    `run_averaging`, and checkpoints weights + telemetry into a
    per-agent save directory.

    Relies on module-level `args`, `best_prec1`, `parser`, `resnet`,
    `ConsensusAgent`, `train`, `validate`, `save_checkpoint` and
    `ModelStatistics`.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Fixed seed for run-to-run reproducibility.
    torch.manual_seed(239)

    print('Consensus agent: {}'.format(args.agent_token))
    convergence_eps = 1e-4
    agent = ConsensusAgent(args.agent_token,
                           args.agent_host,
                           args.agent_port,
                           args.master_host,
                           args.master_port,
                           convergence_eps=convergence_eps,
                           debug=True if args.debug else False)
    # Background task serving peer/master requests for the whole run.
    agent_serve_task = asyncio.create_task(agent.serve_forever())
    print('{}: Created serving task'.format(args.agent_token))

    # Check the save_dir exists or not
    # (each agent checkpoints into its own subdirectory).
    args.save_dir = os.path.join(args.save_dir, str(args.agent_token))
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    statistics = ModelStatistics(args.agent_token)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            if args.logging:
                print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # Telemetry may live inside the checkpoint (pickled) or in a
            # separate statistics.pickle file next to it.
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(args.resume,
                                             'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(args.resume, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            if args.logging:
                # Fix: report the checkpoint that was actually loaded
                # (args.resume) — the message previously printed
                # args.evaluate.
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if args.logging:
                print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Standard ImageNet-style channel statistics for normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Each agent keeps its own copy of the dataset on disk.
    dataset_path = os.path.join('./data/', str(args.agent_token))
    train_dataset = datasets.CIFAR10(root=dataset_path,
                                     train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomCrop(32, 4),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]),
                                     download=True)

    # Default (IID) split: contiguous equal-size shard per agent.
    size_per_agent = len(train_dataset) // args.total_agents
    train_indices = list(
        range(args.agent_token * size_per_agent,
              min(len(train_dataset),
                  (args.agent_token + 1) * size_per_agent)))

    if args.target_split:
        # Non-IID split: this agent only trains on samples whose label
        # equals its token.
        # Fix: the original indexed a Python list with the boolean
        # `train_dataset.targets == args.agent_token`, which is always
        # False and therefore selected element 0 instead of the
        # per-class index set.
        train_indices = [
            i for i, target in enumerate(train_dataset.targets)
            if target == args.agent_token
        ]
        print('Target split: {} samples for agent {}'.format(
            len(train_indices), args.agent_token))

    from torch.utils.data.sampler import SubsetRandomSampler
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # !!!!!
        num_workers=args.workers,
        pin_memory=True,
        sampler=SubsetRandomSampler(train_indices))

    val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        root=dataset_path,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=128,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    def lr_schedule(epoch):
        # Linear-scaling-rule base factor (lr * total_agents), with the
        # usual x10 step decays at epochs 81 and 122.
        factor = args.total_agents
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if args.arch != 'resnet20':
        print(
            'This code was not intended to be used on resnets other than resnet20'
        )

    if args.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
        # then switch back. In this setup it will correspond for first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    def dump_params(model):
        # Flatten the whole state dict into one float32 numpy vector so
        # it can be averaged with peers.
        return torch.cat([
            v.to(torch.float32).view(-1)
            for k, v in model.state_dict().items()
        ]).cpu().numpy()

    def load_params(model, params):
        # Inverse of dump_params: slice the flat vector back into the
        # state dict's original shapes/dtypes/devices.
        st = model.state_dict()
        used_params = 0
        for k in st.keys():
            cnt_params = st[k].numel()
            st[k] = torch.Tensor(params[used_params:used_params + cnt_params]).view(st[k].shape)\
                .to(st[k].dtype).to(st[k].device)
            used_params += cnt_params
        model.load_state_dict(st)

    async def run_averaging():
        # One consensus round: push our parameters, receive the average.
        params = dump_params(model)
        params = await agent.run_once(params)
        load_params(model, params)

    if args.logging:
        print('Starting initial averaging...')

    # Initial synchronization so every agent starts from the same
    # weights (the designated leader's weights win).
    params = dump_params(model)
    params = await agent.run_round(params, 1.0 if args.init_leader else 0.0)
    load_params(model, params)

    if args.logging:
        print('Initial averaging completed!')

    # Fix: fast-forward the LR schedule when resuming so the learning
    # rate matches an uninterrupted run.
    for epoch in range(0, args.start_epoch):
        lr_scheduler.step()

    for epoch in range(args.start_epoch, args.epochs):
        statistics.set_epoch(epoch)
        # train for one epoch
        if args.logging:
            print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        await train(train_loader, model, criterion, optimizer, epoch,
                    statistics, run_averaging)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # evaluate on validation set
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # Full checkpoint (weights + epoch + pickled telemetry) every
        # `save_every` epochs.
        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'statistics': pickle.dumps(statistics)
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        # Lightweight model-only checkpoint every epoch.
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.th'))
        statistics.dump_to_file(
            os.path.join(args.save_dir, 'statistics.pickle'))
예제 #5
0
def plot_loop(names, paths, title, save=None, param_dev=None):
    """Continuously re-plot training telemetry until interrupted.

    Every 5 seconds reloads the `ModelStatistics` dumps at `paths`
    (labelled by `names`) and redraws the loss / parameter-deviation /
    validation-accuracy panels in the notebook, optionally saving the
    figure. SIGINT (captured by `Finish`) ends the loop gracefully.

    Args:
        names: labels for the statistics files, parallel to `paths`.
        paths: paths of pickled ModelStatistics dumps to watch.
        title: figure title.
        save: optional path to also save each rendered figure to.
        param_dev: optional path of the master-telemetry statistics file
            holding parameter-deviation records; adds a third panel.
    """
    finish = Finish()
    signal.signal(signal.SIGINT, finish)

    plt.rcParams['font.size'] = 18
    plt.rcParams['axes.facecolor'] = 'white'
    plt.rcParams['figure.facecolor'] = 'white'

    while not finish.finished():
        # Reload every iteration: the files are being written by the
        # still-running training processes.
        stats = {name: ModelStatistics.load_from_file(path) for name, path in zip(names, paths)}
        param_dev_stats = ModelStatistics.load_from_file(param_dev) if param_dev else None

        if param_dev_stats:
            fig, (plt_loss, plt_param_dev, plt_val_acc) = plt.subplots(3, 1)
            fig.set_size_inches(18, 20)
        else:
            plt_param_dev = None
            fig, (plt_loss, plt_val_acc) = plt.subplots(2, 1)
            fig.set_size_inches(18, 20 * 0.7)
        fig.suptitle(title, fontsize=20)

        plt_loss.set_ylabel('Loss (local)')
        plt_loss.set_yscale('log')
        plt_loss.set_xlabel('Epoch')

        if param_dev_stats:
            plt_param_dev.set_ylabel('Parameter deviation (coef. of variation)')
            plt_param_dev.set_xlabel('Epoch')
            plt_param_dev.set_yscale('log')

        plt_val_acc.set_ylabel('Validation Accuracy, %')
        plt_val_acc.set_xlabel('Epoch')

        for label, stat in stats.items():
            loss = stat.crop('train_loss')
            val_acc = stat.crop('val_precision')

            # Consensus runs are drawn dashed/thinner so they are easy
            # to tell apart from baseline runs.
            fmt = {}
            if label.lower().find('consensus') != -1:
                fmt['linestyle'] = 'dashed'
                fmt['linewidth'] = 1.1
            else:
                fmt['linestyle'] = None
                fmt['linewidth'] = 1.5

            plt_loss.plot(range(len(loss)), loss, label=label, **fmt)
            plt_val_acc.plot(range(len(val_acc)), val_acc, label=label + ' ({})'.format(val_acc[-1]), **fmt)

        if param_dev_stats:
            # Deviation records are best-effort: incomplete telemetry
            # files may miss keys, in which case the panel data is just
            # skipped for this refresh.
            # Fix: the original bare `except:` clauses also swallowed
            # KeyboardInterrupt/SystemExit, defeating the SIGINT-based
            # shutdown; narrowed to `except Exception:`.
            try:
                telemetries_per_epoch = next(iter(param_dev_stats.crop('telemetries_per_epoch')[0].values()))
                try:
                    deviation = param_dev_stats.crop('coef_of_var')
                    plt_param_dev.plot([b / telemetries_per_epoch for b in range(len(deviation))],
                                       deviation, label='max')
                except Exception:
                    pass
                # Newer telemetry uses the 'abs_' key; fall back to the
                # older name for backwards compatibility.
                try:
                    cv_pctls = param_dev_stats.crop('abs_coef_of_var_percentiles')
                except Exception:
                    cv_pctls = param_dev_stats.crop('coef_of_var_percentiles')
                grouped_by_pcts = dict()
                for record in cv_pctls:
                    for (pct, val) in record:
                        if pct not in grouped_by_pcts.keys():
                            grouped_by_pcts[pct] = []
                        grouped_by_pcts[pct].append(val)
                for pct, vals in reversed(list(grouped_by_pcts.items())):
                    # Only the informative 75th..99th percentiles.
                    if pct < 75 or 99 < pct:
                        continue
                    plt_param_dev.plot([b / telemetries_per_epoch for b in range(len(vals))],
                                       vals, label='percentile={}'.format(pct))
            except Exception:
                pass

        plt_loss.legend()
        plt_val_acc.legend()
        if param_dev_stats:
            plt_param_dev.legend()

        fig.tight_layout()
        # Close before display(): the notebook renders the figure object
        # itself, and closing keeps matplotlib from accumulating open
        # figures across refreshes.
        plt.close(fig)
        clear_output(wait=True)
        display(fig)

        if save is not None:
            fig.savefig(save)

        time.sleep(5.0)
예제 #6
0
async def main(cfg):
    """Entry point for one consensus training agent (config-driven variant).

    Trains a ResNet on this agent's data shard while periodically
    averaging with peers through `ConsensusSpecific`, and checkpoints
    weights + telemetry under a per-agent save directory.

    Args:
        cfg: configuration namespace (agent_token, save_dir, arch, lr,
            epochs, ...) — full schema defined elsewhere in the project.
    """
    best_prec1 = 0
    # Fixed seed for run-to-run reproducibility.
    torch.manual_seed(239)

    print('Consensus agent: {}'.format(cfg.agent_token))
    consensus_specific = ConsensusSpecific(cfg)
    consensus_specific.init_consensus()

    # Check the save_dir exists or not
    # (each agent checkpoints into its own subdirectory).
    cfg.save_dir = os.path.join(cfg.save_dir, str(cfg.agent_token))
    if not os.path.exists(cfg.save_dir):
        os.makedirs(cfg.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[cfg.arch]())
    model.cuda()
    print('{}: Created model'.format(cfg.agent_token))

    statistics = ModelStatistics(cfg.agent_token)

    # optionally resume from a checkpoint
    if cfg.do_resume:
        checkpoint_path = os.path.join(cfg.save_dir, 'checkpoint.th')
        if os.path.isfile(checkpoint_path):
            if cfg.logging:
                print("=> loading checkpoint '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path)
            cfg.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # Telemetry may live inside the checkpoint (pickled) or in a
            # separate statistics.pickle file next to it.
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(cfg.save_dir,
                                             'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(cfg.save_dir, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            if cfg.logging:
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    checkpoint_path, checkpoint['epoch']))
        else:
            if cfg.logging:
                print("=> no checkpoint found at '{}'".format(checkpoint_path))

    cudnn.benchmark = True

    print('{}: Loading dataset...'.format(cfg.agent_token))
    train_loader = get_agent_train_loader(cfg.agent_token, cfg.batch_size)
    print('{}: loaded {} batches for train'.format(cfg.agent_token,
                                                   len(train_loader)))
    # NOTE(review): with no_validation set, val_loader is None — this
    # assumes `validate` tolerates a None loader; confirm.
    val_loader = None if cfg.no_validation else get_agent_val_loader(
        cfg.agent_token)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if cfg.half:
        # Half precision for both model and loss.
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                cfg.lr,
                                momentum=cfg.momentum,
                                weight_decay=cfg.weight_decay)

    def lr_schedule(epoch):
        # Step-decay schedule (x10 drops at epochs 81 and 122), with an
        # optional linear-scaling-rule warm-up: the factor ramps
        # geometrically from 1 to total_agents over `warmup` epochs.
        if cfg.use_lsr and epoch < cfg.warmup:
            factor = np.power(cfg.total_agents, epoch / cfg.warmup)
        else:
            factor = cfg.total_agents if cfg.use_lsr else 1.0
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if cfg.arch != 'resnet20':
        print(
            'This code was not intended to be used on resnets other than resnet20'
        )

    if cfg.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
        # then switch back. In this setup it will correspond for first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = cfg.lr * 0.1

    if cfg.evaluate:
        validate(cfg, val_loader, model, criterion)
        return

    # Tell the master how many telemetry records to expect per epoch.
    await consensus_specific.agent.send_telemetry(
        TelemetryAgentGeneralInfo(
            telemetries_per_epoch=cfg.telemetry_freq_per_epoch))

    # Fast-forward the LR schedule so a resumed run continues with the
    # same learning rate as an uninterrupted one.
    for epoch in range(0, cfg.start_epoch):
        lr_scheduler.step()

    for epoch in range(cfg.start_epoch, cfg.epochs):
        # train for one epoch
        if cfg.logging:
            print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        await train(consensus_specific, train_loader, model, criterion,
                    optimizer, epoch, statistics)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # evaluate on validation set
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(cfg, val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        # Full checkpoint (weights + epoch + pickled telemetry) every
        # `save_every` epochs.
        if epoch > 0 and epoch % cfg.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'statistics': pickle.dumps(statistics)
                },
                is_best,
                filename=os.path.join(cfg.save_dir, 'checkpoint.th'))

        # Lightweight model-only checkpoint every epoch.
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(cfg.save_dir, 'model.th'))
        statistics.dump_to_file(os.path.join(cfg.save_dir,
                                             'statistics.pickle'))

    consensus_specific.stop_consensus()
예제 #7
0
class ResNet20TelemetryProcessor(TelemetryProcessor):
    """Master-side telemetry sink for distributed ResNet20 training.

    Buffers per-agent parameter snapshots keyed by batch number; once
    every agent has reported for a batch, records parameter-deviation
    norms and the coefficient of variation into a ModelStatistics file.
    Also records each agent's batches-per-epoch once all general-info
    payloads have arrived.
    """

    def __init__(self, destination_path, topology):
        self.destination_path = destination_path
        # Every endpoint mentioned in the (u, v) edge list is an agent.
        self.agents = list({uv[0] for uv in topology}
                           | {uv[1] for uv in topology})

        self.stats = ModelStatistics('MASTER TELEMETRY',
                                     save_path=destination_path)
        # batch_number -> {agent token -> flat parameter vector}
        self.agent_params_by_iter = {}
        # agent token -> general-info payload
        self.agent_general_info = {}

    def process(self, token, payload):
        """Dispatch one telemetry payload received from agent `token`.

        Raises:
            ValueError: for payload types this processor does not know.
        """
        if isinstance(payload, TelemetryModelParameters):
            self._on_model_parameters(token, payload)
        elif isinstance(payload, TelemetryAgentGeneralInfo):
            self._on_general_info(token, payload)
        else:
            raise ValueError(
                f'Got unsupported payload from {token}: {payload!r}')

    def _on_model_parameters(self, token, payload):
        # Accumulate snapshots for this batch until every agent reported.
        per_batch = self.agent_params_by_iter.setdefault(
            payload.batch_number, {})
        per_batch[token] = payload.parameters
        if len(per_batch) != len(self.agents):
            return

        ordered = [per_batch[agent] for agent in self.agents]
        avg_params = np.mean(ordered, axis=0)
        deviations = [p - avg_params for p in ordered]

        # Per-agent deviation from the mean, in three norms.
        for key, order in (('param_deviation_L1', 1),
                           ('param_deviation_L2', 2),
                           ('param_deviation_Linf', np.inf)):
            self.stats.add(
                key, {
                    agent: np.linalg.norm(dev, ord=order)
                    for agent, dev in zip(self.agents, deviations)
                })

        # Maximum (over parameters) coefficient of variation across
        # agents: how far the fleet's weights have drifted apart.
        stacked = np.array(ordered)
        max_cv = np.linalg.norm(np.std(stacked, axis=0) /
                                np.mean(stacked, axis=0),
                                ord=np.inf)
        self.stats.add('coef_of_var', max_cv)

        self.stats.dump_to_file()
        # This batch is fully processed; free its buffered snapshots.
        del self.agent_params_by_iter[payload.batch_number]

    def _on_general_info(self, token, payload):
        self.agent_general_info[token] = payload
        # Record once every agent has announced itself.
        if len(self.agent_general_info) == len(self.agents):
            self.stats.add(
                'batches_per_epoch', {
                    agent: self.agent_general_info[agent].batches_per_epoch
                    for agent in self.agents
                })
            self.stats.dump_to_file()