Example #1
def main():
    global args, best_prec1
    args = parser.parse_args()

    torch.manual_seed(239)

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

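    # resnet.__dict__[args.arch] looks up the model constructor (e.g. resnet20) by name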
    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    statistics = ModelStatistics('Single model')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(args.resume,
                                             'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(args.resume, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

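    # Let cuDNN benchmark and cache the fastest convolution algorithms (input sizes are fixed)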
    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.CIFAR10(root='./data',
                                     train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomCrop(32, 4),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]),
                                     download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.CIFAR10(root='./data',
                                   train=False,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       normalize,
                                   ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=128,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

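    # LambdaLR multiplies the base lr by the returned factor:
    # x1 until epoch 80, x0.1 from epoch 81, x0.01 from epoch 122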
    def lr_schedule(epoch):
        factor = 1
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if args.arch != 'resnet20':
        print(
            'This code was not intended to be used on resnets other than resnet20'
        )

    if args.arch in ['resnet1202', 'resnet110']:
        # For resnet1202 the original paper uses lr=0.01 for the first 400 minibatches as
        # warm-up and then switches back; in this setup that corresponds to the first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        statistics.set_epoch(epoch)
        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        train(train_loader, model, criterion, optimizer, epoch, statistics)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # evaluate on validation set
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'statistics': pickle.dumps(statistics)
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.th'))
        statistics.dump_to_file(
            os.path.join(args.save_dir, 'statistics.pickle'))

Example #2

async def main():
    global args, best_prec1
    args = parser.parse_args()

    torch.manual_seed(239)

    print('Consensus agent: {}'.format(args.agent_token))
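    # The consensus agent exchanges model parameters with its peers; convergence_eps is
    # assumed to be the tolerance at which a consensus round is considered converged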
    convergence_eps = 1e-4
    agent = ConsensusAgent(args.agent_token,
                           args.agent_host,
                           args.agent_port,
                           args.master_host,
                           args.master_port,
                           convergence_eps=convergence_eps,
                           debug=True if args.debug else False)
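    # Run the agent's networking loop concurrently in the background for the whole run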
    agent_serve_task = asyncio.create_task(agent.serve_forever())
    print('{}: Created serving task'.format(args.agent_token))

    # Check the save_dir exists or not
    args.save_dir = os.path.join(args.save_dir, str(args.agent_token))
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    model.cuda()

    statistics = ModelStatistics(args.agent_token)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            if args.logging:
                print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if 'statistics' in checkpoint.keys():
                statistics = pickle.loads(checkpoint['statistics'])
            elif os.path.isfile(os.path.join(args.resume,
                                             'statistics.pickle')):
                statistics = ModelStatistics.load_from_file(
                    os.path.join(args.resume, 'statistics.pickle'))
            model.load_state_dict(checkpoint['state_dict'])
            if args.logging:
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.evaluate, checkpoint['epoch']))
        else:
            if args.logging:
                print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    dataset_path = os.path.join('./data/', str(args.agent_token))
    train_dataset = datasets.CIFAR10(root=dataset_path,
                                     train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomCrop(32, 4),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]),
                                     download=True)

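    # Each agent trains on a contiguous 1/total_agents shard of CIFAR-10, selected by its token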
    size_per_agent = len(train_dataset) // args.total_agents
    train_indices = list(
        range(args.agent_token * size_per_agent,
              min(len(train_dataset),
                  (args.agent_token + 1) * size_per_agent)))

    if args.target_split:
        # Non-IID split: keep only samples whose class label equals this agent's token.
        # train_dataset.targets is a plain list, so boolean-mask indexing would not work here.
        train_indices = [
            i for i, target in enumerate(train_dataset.targets)
            if target == args.agent_token
        ]
        print('Target split: {} samples for agent {}'.format(
            len(train_indices), args.agent_token))

    from torch.utils.data.sampler import SubsetRandomSampler
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # must stay False: SubsetRandomSampler already randomizes the shard
        num_workers=args.workers,
        pin_memory=True,
        sampler=SubsetRandomSampler(train_indices))

    val_dataset = datasets.CIFAR10(root=dataset_path,
                                   train=False,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       normalize,
                                   ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=128,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

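    # Same step schedule as the single-model script, but the base factor is multiplied by
    # the number of agents (a linear lr-scaling heuristic for the distributed setting)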
    def lr_schedule(epoch):
        factor = args.total_agents
        if epoch >= 81:
            factor /= 10
        if epoch >= 122:
            factor /= 10
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_schedule)

    if args.arch != 'resnet20':
        print(
            'This code was not intended to be used on resnets other than resnet20'
        )

    if args.arch in ['resnet1202', 'resnet110']:
        # For resnet1202 the original paper uses lr=0.01 for the first 400 minibatches as
        # warm-up and then switches back; in this setup that corresponds to the first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

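    # Flatten the entire state_dict into a single float32 numpy vector for the consensus step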
    def dump_params(model):
        return torch.cat([
            v.to(torch.float32).view(-1)
            for k, v in model.state_dict().items()
        ]).cpu().numpy()

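    # Inverse of dump_params: slice the flat vector back into the state_dict tensors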
    def load_params(model, params):
        st = model.state_dict()
        used_params = 0
        for k in st.keys():
            cnt_params = st[k].numel()
            st[k] = torch.Tensor(params[used_params:used_params + cnt_params]).view(st[k].shape)\
                .to(st[k].dtype).to(st[k].device)
            used_params += cnt_params
        model.load_state_dict(st)

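    # One consensus averaging step: flatten the local weights, average them with the peers
    # through the agent, and load the averaged result back into the model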
    async def run_averaging():
        params = dump_params(model)
        params = await agent.run_once(params)
        load_params(model, params)

    if args.logging:
        print('Starting initial averaging...')

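    # Initial weighted consensus round: only the init leader contributes with weight 1.0,
    # so all agents presumably start training from the leader's parameters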
    params = dump_params(model)
    params = await agent.run_round(params, 1.0 if args.init_leader else 0.0)
    load_params(model, params)

    if args.logging:
        print('Initial averaging completed!')

    for epoch in range(args.start_epoch, args.epochs):
        statistics.set_epoch(epoch)
        # train for one epoch
        if args.logging:
            print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        statistics.add('train_begin_timestamp', time.time())
        await train(train_loader, model, criterion, optimizer, epoch,
                    statistics, run_averaging)
        lr_scheduler.step()
        statistics.add('train_end_timestamp', time.time())

        # evaluate on validation set
        statistics.add('validate_begin_timestamp', time.time())
        prec1 = validate(val_loader, model, criterion)
        statistics.add('validate_end_timestamp', time.time())
        statistics.add('val_precision', prec1)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'statistics': pickle.dumps(statistics)
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.th'))
        statistics.dump_to_file(
            os.path.join(args.save_dir, 'statistics.pickle'))