Example 1
def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False, synthetic=False):
    val_bs = val_bs or bs
    train_tfms = [
            transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
            transforms.RandomHorizontalFlip()
    ]
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
    train_sampler = (DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank()) if distributed else None)

    if synthetic:
        print("Using synthetic dataloader")
        train_loader = SyntheticDataLoader(bs, (3, sz, sz))
    elif util.is_set('PYTORCH_USE_SPAWN'):
        print("Using SPAWN method for dataloader")
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=bs, shuffle=(train_sampler is None),
            num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
            sampler=train_sampler,
            multiprocessing_context='spawn')
    else:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=bs, shuffle=(train_sampler is None),
            num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
            sampler=train_sampler)

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
        batch_sampler=val_sampler)

    train_loader = BatchTransformDataLoader(train_loader, fp16=fp16)
    val_loader = BatchTransformDataLoader(val_loader, fp16=fp16)

    return train_loader, val_loader, train_sampler, val_sampler
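
A minimal usage sketch for get_loaders above (paths and sizes are hypothetical; the torchvision transforms/datasets and the helper functions referenced in the snippet are assumed importable):
train_dl, val_dl, train_sampler, val_sampler = get_loaders(
    '/data/imagenet/train', '/data/imagenet/validation',  # hypothetical directories
    sz=224, bs=256, fp16=True, workers=8, distributed=False)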
Example 2
 def __init__(self, batch_size, input_shape):
     # Create two random tensors - one for data and one for label
     self.input_shape = input_shape
     data_shape = (batch_size,) + input_shape
     self.data = torch.randn(data_shape)
     self.labels = torch.from_numpy(np.random.randint(0, 1000, batch_size).astype(np.int64))
     self.prefetchable = False
     self.data = self.data.cuda()
     self.labels = self.labels.cuda()
     self.finish = 0
     self.batch_size = batch_size
     self.batch_sampler = SyntheticBatchSampler(batch_size)
     self.batch_num = 1281167 // (env_world_size() * self.batch_sampler.batch_size) + 1  # 1281167 = ImageNet-1k training images
Example 3
 def __next__(self):
     # Support BatchTransformDataLoader.update_batch_size()
     if self.batch_size != self.batch_sampler.batch_size:
         self.batch_size = self.batch_sampler.batch_size
         data_shape = (self.batch_size,) + self.input_shape
         self.data = torch.randn(data_shape)
         self.labels = torch.from_numpy(np.random.randint(0, 1000, self.batch_size).astype(np.int64))
         self.data = self.data.cuda()
         self.labels = self.labels.cuda()
         self.batch_num = 1281167 // (env_world_size() * self.batch_sampler.batch_size) + 1
     if self.finish >= self.batch_num:
         self.finish = 0
         raise StopIteration
     self.finish += 1
     return (self.data, self.labels)
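
Examples 2 and 3 only show __init__ and __next__; to be iterated like the synthetic train_loader in Example 1, the class also needs an __iter__. A minimal sketch of that missing piece, stated as an assumption rather than the project's actual code:
 def __iter__(self):
     # The loader is its own iterator; __next__ already resets self.finish when it raises StopIteration.
     return self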
Example 4
    def __init__(self, indices, batch_size, distributed=True):
        self.indices = indices
        self.batch_size = batch_size
        if distributed:
            self.world_size = env_world_size()
            self.global_rank = env_rank()
        else:
            self.global_rank = 0
            self.world_size = 1

        # Expected number of batches per GPU. Needed so every distributed GPU validates
        # on the same number of batches, even if there isn't enough data to go around.
        self.expected_num_batches = math.ceil(len(self.indices) / self.world_size / self.batch_size)

        # num_samples ~= total images / world_size, rounded up to a whole number of batches; this is what each GPU gets
        self.num_samples = self.expected_num_batches * self.batch_size
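
For intuition, the arithmetic above with hypothetical numbers: 50,000 validation indices, world_size = 8 and batch_size = 256 give expected_num_batches = ceil(50000 / 8 / 256) = 25 and num_samples = 25 * 256 = 6400 per GPU, slightly more than the 6,250 images actually available per rank, so the rest of the sampler (not shown) presumably pads or reuses indices to keep every rank on the same batch count.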
Example 5
def get_loaders(traindir,
                valdir,
                sz,
                bs,
                fp16=False,
                val_bs=None,
                workers=8,
                rect_val=False,
                min_scale=0.08,
                distributed=False):
    val_bs = val_bs or bs
    train_tfms = [
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip()
    ]
    train_dataset = datasets.ImageFolder(traindir,
                                         transforms.Compose(train_tfms))
    train_sampler = (DistributedSampler(
        train_dataset, num_replicas=env_world_size(), rank=env_rank())
                     if distributed else None)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=bs,
                                               shuffle=(train_sampler is None),
                                               num_workers=workers,
                                               pin_memory=True,
                                               collate_fn=fast_collate,
                                               sampler=train_sampler)

    val_dataset, val_sampler = create_validation_set(valdir,
                                                     val_bs,
                                                     sz,
                                                     rect_val=rect_val,
                                                     distributed=distributed)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             num_workers=workers,
                                             pin_memory=True,
                                             collate_fn=fast_collate,
                                             batch_sampler=val_sampler)

    train_loader = BatchTransformDataLoader(train_loader, fp16=fp16)
    val_loader = BatchTransformDataLoader(val_loader, fp16=fp16)

    return train_loader, val_loader, train_sampler, val_sampler
Example 6
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(), rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log: logger.info("Distributed: success (%d/%d)" % (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    else:
        device = torch.cuda.current_device()

    # import pdb; pdb.set_trace()
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed: model = dist_utils.DDP(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log: logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
    if write_log: logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                              nesterov=False)

    # restore parameters
    # import pdb; pdb.set_trace()
    if args.resume is not None:
        # import pdb; pdb.set_trace()
        print('resume from checkpoint')
        checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log: fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
        '''
    elif args.resume and args.validate:
        chkdir = os.path.dirname(args.resume)
        wall_clock = 0
        itr = 0
        best_loss = 0.0
        begin_epoch = 0
        '''
    else:
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        # import pdb; pdb.set_trace()
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, begin_epoch + 1):
        # compute test loss
        print('Evaluating')
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            # import pdb; pdb.set_trace()
            if hasattr(model, 'module'):
                _state = model.module.state_dict()
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,  # model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real and generate with different temperatures
        fig_num = 64
        if True:  # args.save_real:
            for i, (x, y) in enumerate(test_loader):
                if i < 100:
                    pass
                elif i == 100:
                    real = x.size(0)
                else:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            # import pdb; pdb.set_trace()
            fig_filename = os.path.join(chkdir, "real.jpg")
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
            for t in [1.0, 0.99, 0.98, 0.97, 0.96, 0.95, 0.93, 0.92, 0.90, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
Example 7
def get_dataset(args):
    trans = lambda im_size: tforms.Compose([tforms.Resize(im_size)])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root=args.datadir, train=True, transform=trans(im_size), download=True)
        test_set = dset.MNIST(root=args.datadir, train=False, transform=trans(im_size), download=True)
    elif args.data == "svhn":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.SVHN(root=args.datadir, split="train", transform=trans(im_size), download=True)
        test_set = dset.SVHN(root=args.datadir, split="test", transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root=args.datadir, train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
            ]), download=True
        )
        test_set = dset.CIFAR10(root=args.datadir, train=False, transform=None, download=True)
    elif args.data == 'celebahq':
        im_dim = 3
        im_size = 256 if args.imagesize is None else args.imagesize
        ''' 
        train_set = CelebAHQ(
            train=True, root=args.datadir, transform=tforms.Compose([
                tforms.ToPILImage(),
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
            ])
        )
        test_set = CelebAHQ(
            train=False, root=args.datadir,  transform=tforms.Compose([
                tforms.ToPILImage(),
                tforms.Resize(im_size),
            ])
        )
        '''
        train_set = CelebAHQ(train=True, root=args.datadir)
        test_set = CelebAHQ(train=False, root=args.datadir)
    elif args.data == 'imagenet64':
        im_dim = 3
        if args.imagesize != 64:
            args.imagesize = 64
        im_size = 64
        train_set = Imagenet64(train=True, root=args.datadir)
        test_set = Imagenet64(train=False, root=args.datadir)
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        train_set = dset.LSUN(
            'data', ['church_outdoor_train'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ])
        )
        test_set = dset.LSUN(
            'data', ['church_outdoor_val'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ])
        )
    data_shape = (im_dim, im_size, im_size)

    def fast_collate(batch):
        # Collate a list of (PIL image, label) pairs into a uint8 NCHW tensor and an int64 label tensor.
        imgs = [img[0] for img in batch]
        targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
        w = imgs[0].size[0]
        h = imgs[0].size[1]

        tensor = torch.zeros((len(imgs), im_dim, im_size, im_size), dtype=torch.uint8)
        for i, img in enumerate(imgs):
            nump_array = np.asarray(img, dtype=np.uint8)
            tens = torch.from_numpy(nump_array)
            if (nump_array.ndim < 3):
                nump_array = np.expand_dims(nump_array, axis=-1)
            nump_array = np.rollaxis(nump_array, 2)
            tensor[i] += torch.from_numpy(nump_array)

        return tensor, targets

    train_sampler = (DistributedSampler(train_set,
                                        num_replicas=env_world_size(), rank=env_rank()) if args.distributed
                     else None)

    if not args.distributed:
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set, batch_size=args.batch_size, shuffle=True,
            num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate
        )
    else:
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set, batch_size=args.batch_size, sampler=train_sampler,
            num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate
        )

    # import pdb
    # pdb.set_trace()

    test_sampler = (DistributedSampler(test_set,
                                       num_replicas=env_world_size(), rank=env_rank()) if args.distributed
                    else None)

    if not args.distributed:
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set, batch_size=args.test_batch_size, shuffle=False,
            num_workers=args.nworkers, pin_memory=True, collate_fn=fast_collate
        )
    else:
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set, batch_size=args.test_batch_size,
            num_workers=args.nworkers, pin_memory=True, sampler=test_sampler, collate_fn=fast_collate
        )
    return train_loader, test_loader, data_shape
Example 8
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        if args.short_epoch and (i > 10): break
        batch_num = i + 1
        timer.batch_start()
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        if args.fp16:
            # Static loss scaling: scale the loss before backward so fp16 gradients
            # do not underflow, then unscale the fp32 master grads before the step.
            loss = loss * args.loss_scale
            model.zero_grad()
            loss.backward()
            model_grads_to_master_grads(model_params, master_params)
            for param in master_params:
                param.grad.data = param.grad.data / args.loss_scale
            optimizer.step()
            master_params_to_model_params(model_params, master_params)
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()
        corr1, corr5 = correct(output.data, target, topk=(1, 5))
        reduced_loss, batch_total = to_python_float(
            loss.data), to_python_float(input.size(0))
        if args.distributed:  # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
            metrics = torch.tensor([batch_total, reduced_loss, corr1,
                                    corr5]).float().cuda()
            batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(
                metrics).cpu().numpy()
            reduced_loss = reduced_loss / dist_utils.env_world_size()
        top1acc = to_python_float(corr1) * (100.0 / batch_total)
        top5acc = to_python_float(corr5) * (100.0 / batch_total)

        losses.update(reduced_loss, batch_total)
        top1.update(top1acc, batch_total)
        top5.update(top5acc, batch_total)

        should_print = (batch_num % args.print_freq
                        == 0) or (batch_num == len(trn_loader))
        if args.local_rank == 0 and should_print:
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val,
                             input.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)

            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)

            output = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(output)

        tb.update_step_count(batch_total)
Example 9
def main():
    os.system('shutdown -c')  # cancel previous shutdown command
    log.console(args)
    tb.log('sizes/world', dist_utils.env_world_size())

    # need to index validation directory before we start counting the time
    dataloader.sort_ar(args.data + '/validation')

    if args.distributed:
        log.console('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=dist_utils.env_world_size())
        assert (dist_utils.env_world_size() == dist.get_world_size())
        log.console("Distributed: success (%d/%d)" %
                    (args.local_rank, dist.get_world_size()))

    log.console("Loading model")
    model = resnet.resnet50(bn0=args.init_bn0).cuda()
    if args.fp16: model = network_to_half(model)
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)
    best_top5 = 93  # only save models over 93%; otherwise we would stop to write a checkpoint every epoch

    global model_params, master_params
    if args.fp16: model_params, master_params = prep_param_lists(model)
    else: model_params = master_params = model.parameters()

    optim_params = experimental_utils.bnwd_optim_params(
        model, model_params, master_params) if args.no_bn_wd else master_params

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        optim_params,
        0,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )  # start with 0 lr. Scheduler will change this later

    if args.resume:
        checkpoint = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # save script so we can reproduce from logs
    shutil.copy2(os.path.realpath(__file__), f'{args.logdir}')

    log.console(
        "Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)"
    )
    phases = eval(args.phases)
    dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p])
    scheduler = Scheduler(optimizer,
                          [copy.deepcopy(p) for p in phases if 'lr' in p])

    start_time = datetime.now()  # start the clock only after everything is loaded
    if args.evaluate:
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    log.event("~~epoch\thours\ttop1\ttop5\n")
    for epoch in range(args.start_epoch, scheduler.tot_epochs):
        dm.set_epoch(epoch)

        train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
        top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time)

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        is_best = top5 > best_top5
        best_top5 = max(top5, best_top5)
        if args.local_rank == 0:
            if is_best:
                save_checkpoint(epoch,
                                model,
                                best_top5,
                                optimizer,
                                is_best=True,
                                filename='model_best.pth.tar')
            phase = dm.get_phase(epoch)
            if phase:
                save_checkpoint(
                    epoch,
                    model,
                    best_top5,
                    optimizer,
                    filename=f'sz{phase["bs"]}_checkpoint.path.tar')
Example 10
 def __len__(self):
     return 1281167 // (env_world_size() * self.batch_sampler.batch_size) + 1  # 1281167 = ImageNet-1k training images
Example 11
def main():
    #os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))

        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device
    # device = torch.device("cuda:%d"%torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time',
        'grad_norm'
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
        'transport_cost'
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] +
                          1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    #x = x.clamp_(min=0, max=1)
                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(
                                           reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log: steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd.item(), nfe_opt, grad_norm, *reg_states
                    ]).float()

                    rv = tuple(torch.tensor(0.) for r in reg_states)

                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save(
                {
                    "args":
                    args,
                    "state_dict":
                    model.module.state_dict()
                    if torch.cuda.is_available() else model.state_dict(),
                    "optim_state_dict":
                    optimizer.state_dict(),
                    "fixed_z":
                    fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))
        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log: csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log: logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        sh = x.shape
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) -
                                z).pow(2).mean(dim=-1).mean()
                        meandist = i / (i + 1) * dist + meandist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = i / (i + 1) * tt + count_total_time(model) / (i +
                                                                           1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i +
                                                                          1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }

                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt,
                                    r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)
            if args.validate:
                break
Example 12
def main():
    # os.system('sudo shutdown -c')  # cancel previous shutdown command
    log.console(args)
    tb.log('sizes/world', dist_utils.env_world_size())

    print(args.data)
    assert os.path.exists(args.data)

    # need to index validation directory before we start counting the time
    dataloader.sort_ar(args.data + '/val')

    if args.distributed:
        log.console('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=dist_utils.env_world_size())
        assert (dist_utils.env_world_size() == dist.get_world_size())
        # todo(y): use global_rank instead of local_rank here
        log.console("Distributed: success (%d/%d)" %
                    (args.local_rank, dist.get_world_size()))

    log.console("Loading model")
    #from mobilenetv3 import MobileNetV3
    #model = MobileNetV3(mode='small', num_classes=1000).cuda()
    if args.network == 'resnet50':
        model = resnet.resnet50(bn0=args.init_bn0).cuda()
    elif args.network == 'resnet50friendlyv1':
        model = resnet.resnet50friendly(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv2':
        model = resnet.resnet50friendly2(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv3':
        model = resnet.resnet50friendly3(bn0=args.init_bn0, hybrid=True).cuda()
    elif args.network == 'resnet50friendlyv4':
        model = resnet.resnet50friendly4(bn0=args.init_bn0, hybrid=True).cuda()
    #import resnet_friendly
    #model = resnet_friendly.ResNet50Friendly().cuda()
    #model = torchvision.models.mobilenet_v2(pretrained=False).cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)
    best_top5 = 93  # only save models over 93%; otherwise we would stop to write a checkpoint every epoch

    global model_params, master_params
    if args.fp16:
        model_params, master_params = prep_param_lists(model)
    else:
        model_params = master_params = model.parameters()

    optim_params = experimental_utils.bnwd_optim_params(
        model, model_params, master_params) if args.no_bn_wd else master_params

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        optim_params,
        0,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )  # start with 0 lr. Scheduler will change this later

    if args.resume:
        checkpoint = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        current_phase = checkpoint['current_phase']
        best_top5 = checkpoint['best_top5']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # save script so we can reproduce from logs
    shutil.copy2(os.path.realpath(__file__), f'{args.logdir}')

    log.console(
        "Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)"
    )
    # phases = util.text_unpickle(args.phases)
    lr = 0.9
    scale_224 = 224 / 512  # linear LR scaling for the phases that run at batch size 224
    scale_288 = 128 / 512  # LR scale for the 288px phase, which runs at batch size 128
    one_machine = [
        {
            'ep': 0,
            'sz': 128,
            'bs': 512,
            'trndir': ''
        },  # Will this work?  -- No idea! Should we try with mv2 baseline? ???
        {
            'ep': (0, 5),
            'lr': (lr, lr * 2)
        },  # lr warmup is better with --init-bn0
        {
            'ep': 5,
            'lr': lr
        },
        {
            'ep': 14,
            'sz': 224,
            'bs': 224,
            'lr': lr * scale_224
        },
        {
            'ep': 16,
            'lr': lr / 10 * scale_224
        },
        {
            'ep': 32,
            'lr': lr / 100 * scale_224
        },
        {
            'ep': 37,
            'lr': lr / 100 * scale_224
        },
        {
            'ep': 39,
            'sz': 288,
            'bs': 128,
            'min_scale': 0.5,
            'rect_val': True,
            'lr': lr / 100 * scale_288
        },
        {
            'ep': (40, 44),
            'lr': lr / 1000 * scale_288
        },
        #{'ep': (36, 40), 'lr': lr / 1000 * scale_288},
        {
            'ep': (45, 48),
            'lr': lr / 10000 * scale_288
        },
        {
            'ep': (49, 52),
            'sz': 288,
            'bs': 224,
            'lr': lr / 10000 * scale_224
        }
        #{'ep': (46, 50), 'sz': 320, 'bs': 64,  'lr': lr / 10000 * scale_320}
    ]
    phases = util.text_pickle(one_machine)  #Ok? Unpickle?
    phases = util.text_unpickle(phases)
    dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p])
    scheduler = Scheduler(optimizer,
                          [copy.deepcopy(p) for p in phases if 'lr' in p])

    start_time = datetime.now()  # start the clock only after everything is loaded
    if args.evaluate:
        return validate(dm.val_dl, model, criterion, 0, start_time)

    if args.distributed:
        log.console('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    log.event("~~epoch\thours\ttop1\ttop5\n")
    for epoch in range(args.start_epoch, scheduler.tot_epochs):
        print(" The start epoch:", args.start_epoch)
        dm.set_epoch(epoch)

        train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch)
        top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time)

        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0
        log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n')

        is_best = top5 > best_top5
        best_top5 = max(top5, best_top5)
        phase_save = dm.get_phase(epoch)
        if args.local_rank == 0:
            if is_best:
                save_checkpoint(phase_save,
                                epoch,
                                model,
                                best_top5,
                                optimizer,
                                is_best=True,
                                filename='model_best_' + args.network +
                                args.name + '.pth.tar')
            else:
                save_checkpoint(phase_save,
                                epoch,
                                model,
                                top5,
                                optimizer,
                                is_best=False,
                                filename='model_epoch_latest_' + args.network +
                                args.name + '.pth.tar')
            phase = dm.get_phase(epoch)
            if phase:
                save_checkpoint(
                    phase_save,
                    epoch,
                    model,
                    best_top5,
                    optimizer,
                    filename=f'sz{phase["bs"]}_checkpoint.path.tar')