Example #1
0
    def step(self, epoch=None):
        """Advance the learning-rate schedule by one optimizer step.

        The parent scheduler's ``step()`` only runs when the newly computed
        learning rate differs from the last applied one by at least
        ``self._min_delta_to_update_lr``. Otherwise only ``self._step_count`` is
        advanced, mirroring what the parent ``step()`` would have done.

        When a summary writer is configured, the optimizer's current learning
        rate is written under 'LearningRate' in two cases: on every step during
        warmup epochs, and afterwards only when the step count is a multiple of
        ``self._num_steps_per_epoch``.

        Args:
          epoch: Accepted for signature compatibility; not read here.
        """
        new_lr = self.get_lr()[0]
        lr_moved = abs(new_lr - self._previous_lr) >= self._min_delta_to_update_lr

        if not lr_moved:
            # Skipping super().step(), so bump the counter it would have bumped.
            self._step_count += 1
        else:
            super(WarmupAndExponentialDecayScheduler, self).step()
            self._previous_lr = new_lr

        writer = self._summary_writer
        if not writer:
            return

        # During warmup the LR changes each step, so log every step. After
        # warmup the LR is constant within an epoch, so log only on the
        # epoch-boundary step. The modulo is evaluated only outside warmup.
        in_warmup = self._epoch() < self._num_warmup_epochs
        if in_warmup or self._step_count % self._num_steps_per_epoch == 0:
            test_utils.add_scalar_to_summary(
                writer, 'LearningRate',
                self.optimizer.param_groups[0]['lr'], self._step_count)
def train_imagenet():
    """Train and evaluate an ImageNet classifier on the local XLA device.

    Builds train/test loaders: synthetic zero tensors when ``FLAGS.fake_data``
    is set, otherwise ImageFolder datasets under ``FLAGS.datadir``. It then
    trains the model from ``get_model_property('model_fn')`` with SGD and the
    scheduler from ``schedulers.wrap_optimizer_with_scheduler`` for
    ``FLAGS.num_epochs`` epochs, and evaluates top-1 accuracy after each epoch.

    Returns:
      The top-1 test accuracy (percent) measured after the final epoch, or 0.0
      if no epochs were run.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        # Each generated sample is one whole batch, so size the count by the
        # test batch size actually used for these tensors (~50k val images).
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.test_set_batch_size //
            xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        # Shard the training set across replicas; each replica sees a disjoint
        # subset, shuffled by the sampler instead of the DataLoader.
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    # Only the master ordinal writes TensorBoard summaries.
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        """Run one training epoch over `loader`, stepping the LR scheduler per batch."""
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        """Return top-1 accuracy (percent) over `loader`; 0.0 if it yields nothing."""
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            # Keep the running count on the device; calling .item() per batch
            # would force a device-to-host transfer on every step.
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]

        if total_samples == 0:
            accuracy = 0.0
        else:
            # float() handles both the initial int 0 and a 0-dim tensor.
            accuracy = 100.0 * float(correct) / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        # A fresh ParallelLoader is built for every pass over a dataset.
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)

        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
Example #3
0
def train_imagenet():
    """Train an ImageNet classifier with torch_xla DataParallel and report accuracy.

    Builds train/test loaders: synthetic zero tensors when ``FLAGS.fake_data``
    is set, otherwise ImageFolder datasets under ``FLAGS.datadir``. It then
    replicates the model over up to ``FLAGS.num_cores`` XLA devices (an empty
    device list when ``FLAGS.num_cores == 0``). Each of ``FLAGS.num_epochs``
    epochs trains, evaluates, prints the mean replica accuracy, and logs it to a
    SummaryWriter under ``FLAGS.logdir`` when one is set.

    Returns:
      Mean top-1 test accuracy (percent) across replicas after the final epoch,
      or 0.0 if no epochs were run.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        # Each generated sample is one whole batch, so size the count by the
        # test batch size actually used for these tensors (~50k val images).
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.test_set_batch_size //
            xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        # Shard both datasets across replicas when running multi-process.
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    torchvision_model = get_model_property('model_fn')
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Train one replica for one epoch.

        The optimizer and scheduler are cached on `context` so they persist
        across epochs. `writer` and `num_training_steps_per_epoch` are assigned
        later in the enclosing function, before the first call to this closure.
        """
        loss_fn = nn.CrossEntropyLoss()
        # NOTE(review): the other ImageNet variants in this file use
        # weight_decay=1e-4 — confirm 5e-4 is intended here.
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        lr_scheduler = context.getattr_or(
            'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
                scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=writer if xm.is_master_ordinal() else None))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
            if lr_scheduler:
                lr_scheduler.step()

    def test_loop_fn(model, loader, device, context):
        """Return one replica's top-1 accuracy (percent); 0.0 if it saw no samples."""
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples if total_samples else 0.0
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         global_step)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    # Flush and release the summary writer (a no-op path when it is None).
    test_utils.close_summary_writer(writer)
    return accuracy
def train_imagenet():
    """Train with torch_xla DataParallel, tracking loss and top-1/top-5 accuracy.

    Builds train/test loaders: synthetic zero tensors when ``FLAGS.fake_data``
    is set, otherwise ImageFolder datasets under ``FLAGS.datadir``. It then
    replicates the model over up to ``FLAGS.num_cores`` XLA devices, or
    ``["cpu"]`` when ``FLAGS.num_cores == 0``. Each of ``FLAGS.num_epochs``
    epochs trains, evaluates, and writes loss/accuracy scalars to a
    SummaryWriter under ``FLAGS.logdir`` when one is set. After the epoch loop,
    the first replica's weights are saved under ``./reports/``.

    Returns:
      The highest mean top-1 test accuracy (percent) seen across epochs, or
      0.0 if no epochs were run.
    """
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size(),
        )
        # Each generated sample is one whole batch, so size the count by the
        # test batch size actually used for these tensors (~50k val images).
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
            ),
            sample_count=50000 // FLAGS.test_set_batch_size // xm.xrt_world_size(),
        )
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose(
                [
                    transforms.RandomResizedCrop(img_dim),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "val"),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose(
                [
                    transforms.Resize(resize_dim),
                    transforms.CenterCrop(img_dim),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        )
        # Shard both datasets across replicas when running multi-process.
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True,
            )
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False,
            )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=True,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=False,
            shuffle=False,
            num_workers=FLAGS.num_workers,
        )

    torch.manual_seed(42)

    # With num_cores == 0 the replicas run on ["cpu"]. The list must be
    # non-empty because devices[0] is used when saving the checkpoint below.
    devices = (
        xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
        if FLAGS.num_cores != 0
        else ["cpu"]
    )
    print("use tpu devices", devices)
    torchvision_model = get_model_property("model_fn")
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Train one replica for one epoch.

        Returns:
          (mean batch loss, top-1 accuracy %, mean per-batch top-5 score).
          Returns all zeros if the loader yields no batches.
        """
        loss_fn = nn.CrossEntropyLoss()
        # Optimizer and scheduler are cached on `context` so they persist
        # across epochs. `writer` and `num_training_steps_per_epoch` are
        # assigned later in the enclosing function, before the first call.
        optimizer = context.getattr_or(
            "optimizer",
            lambda: optim.SGD(
                model.parameters(),
                lr=FLAGS.lr,
                momentum=FLAGS.momentum,
                weight_decay=1e-4,
            ),
        )
        lr_scheduler = context.getattr_or(
            "lr_scheduler",
            lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
                scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, "lr_scheduler_divide_every_n_epochs", None
                ),
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=writer if xm.is_master_ordinal() else None,
            ),
        )
        tracker = xm.RateTracker()
        model.train()
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        losses = 0
        num_steps = 0  # Counted explicitly so an empty loader cannot leave it unbound.
        for x, (data, target) in loader:
            num_steps += 1
            optimizer.zero_grad()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5)
            loss = loss_fn(output, target)
            loss.backward()
            losses += loss.item()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print(
                    "[{}]({}) Loss={:.5f} Top-1 ACC = {:.2f} Rate={:.2f} GlobalRate={:.2f} Time={}".format(
                        str(device),
                        x,
                        loss.item(),
                        (100.0 * correct / total_samples).item(),
                        tracker.rate(),
                        tracker.global_rate(),
                        time.asctime(),
                    )
                )

            if lr_scheduler:
                lr_scheduler.step()
        if num_steps == 0:
            return 0.0, 0.0, 0.0
        return (
            losses / num_steps,
            (100.0 * correct / total_samples).item(),
            (top5_accuracys / num_steps).item(),
        )

    def test_loop_fn(model, loader, device, context):
        """Evaluate one replica.

        Returns:
          (top-1 accuracy % over this replica's samples, or 0.0 if it saw none;
           sum of the per-batch top-5 scores).
        """
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5).item()

        accuracy = 100.0 * correct / total_samples if total_samples else 0.0
        test_utils.print_test_update(device, accuracy)
        return accuracy, top5_accuracys

    accuracy = 0.0
    writer = SummaryWriter(FLAGS.logdir) if FLAGS.logdir else None
    num_devices = len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size * num_devices)
    max_accuracy = 0.0
    epoch = 0  # Keeps the checkpoint filename defined when num_epochs < 1.
    print("train_loader_len", len(train_loader), num_training_steps_per_epoch)
    print("test_loader_len", len(test_loader))
    for epoch in range(1, FLAGS.num_epochs + 1):
        global_step = (epoch - 1) * num_training_steps_per_epoch
        # Train evaluate
        metrics = model_parallel(train_loop_fn, train_loader)
        losses, accuracies, top5_accuracys = zip(*metrics)
        loss = mean(losses)
        accuracy = mean(accuracies)
        top5_accuracy = mean(top5_accuracys)
        print(
            "Epoch: {} (Train), Loss {}, Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".format(
                epoch, loss, accuracy, top5_accuracy
            )
        )
        test_utils.add_scalar_to_summary(writer, "Loss/train", loss, global_step)
        test_utils.add_scalar_to_summary(
            writer, "Top-1 Accuracy/train", accuracy, global_step
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-5 Accuracy/train", top5_accuracy, global_step
        )

        # Test evaluate
        metrics = model_parallel(test_loop_fn, test_loader)
        accuracies, top5_accuracys = zip(*metrics)
        # Per-replica top-5 sums are added together and divided by the number
        # of batches in the test loader.
        num_test_batches = len(test_loader)
        top5_accuracy = (
            sum(top5_accuracys) / num_test_batches if num_test_batches else 0.0
        )
        accuracy = mean(accuracies)
        print(
            "Epoch: {} (Valid), Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".format(
                epoch, accuracy, top5_accuracy
            )
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-1 Accuracy/test", accuracy, global_step
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-5 Accuracy/test", top5_accuracy, global_step
        )
        if FLAGS.metrics_debug:
            print(met.metrics_report())
        max_accuracy = max(max_accuracy, accuracy)

    # NOTE(review): this runs once after the epoch loop, so it stores the
    # final-epoch weights, not necessarily the max-accuracy ones. Confirm this
    # is intended. The replica is moved to CPU for saving and then moved back.
    os.makedirs("./reports", exist_ok=True)
    torch.save(
        model_parallel.models[0].to("cpu").state_dict(),
        f"./reports/resnet50_model-{epoch}.pt",
    )
    model_parallel.models[0].to(devices[0])

    test_utils.close_summary_writer(writer)
    print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy
def train_imagenet():
    """Train a ``create_model`` network on one XLA device, optionally with EMA.

    The model is built with ``create_model`` from ``FLAGS``. When
    ``FLAGS.model_ema`` is set, a ``ModelEma`` copy is also created and updated
    after every optimizer step. Data comes from synthetic zero tensors when
    ``FLAGS.fake_data`` is set; otherwise it is loaded with
    ``Dataset``/``create_loader`` from ``FLAGS.data``/train and
    ``FLAGS.data``/val. Each of ``FLAGS.epochs`` epochs trains with
    ``LabelSmoothingCrossEntropy`` and evaluates top-1 accuracy. The logged
    accuracy is written to a SummaryWriter under ``FLAGS.output`` on the
    master ordinal.

    Returns:
      The accuracy (percent) logged for the final epoch, or 0.0 if no epochs
      ran. This is the EMA model's accuracy when EMA is enabled, otherwise the
      trained model's.
    """
    torch.manual_seed(42)

    device = xm.xla_device()

    def build_model():
        # Shared by the trained model and its EMA copy so both use one config.
        return create_model(
            FLAGS.model,
            pretrained=FLAGS.pretrained,
            num_classes=FLAGS.num_classes,
            drop_rate=FLAGS.drop,
            global_pool=FLAGS.gp,
            bn_tf=FLAGS.bn_tf,
            bn_momentum=FLAGS.bn_momentum,
            bn_eps=FLAGS.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=FLAGS.initial_checkpoint,
            args=FLAGS).to(device)

    model = build_model()
    model_ema = None
    if FLAGS.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper.
        model_e = build_model()
        # model_e is constructed after `model`, so without pretrained or
        # checkpoint weights its random init may differ from `model`'s. Start
        # the running average from the trained model's weights instead.
        model_e.load_state_dict(model.state_dict())
        model_ema = ModelEma(
            model_e,
            decay=FLAGS.model_ema_decay,
            device='cpu' if FLAGS.model_ema_force_cpu else '',
            resume=FLAGS.resume)

    print('==> Preparing data..')
    img_dim = 224
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dir = os.path.join(FLAGS.data, 'train')
        data_config = resolve_data_config(model, FLAGS, verbose=FLAGS.local_rank == 0)
        dataset_train = Dataset(train_dir)
        # Steps per epoch are derived below from the number of training
        # *samples*, as in the fake-data branch. len(train_loader) would count
        # batches instead.
        train_dataset_len = len(dataset_train)

        collate_fn = None
        if not FLAGS.no_prefetcher and FLAGS.mixup > 0:
            collate_fn = FastCollateMixup(FLAGS.mixup, FLAGS.smoothing, FLAGS.num_classes)
        train_loader = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=True,
            use_prefetcher=not FLAGS.no_prefetcher,
            rand_erase_prob=FLAGS.reprob,
            rand_erase_mode=FLAGS.remode,
            interpolation='bicubic',  # FIXME cleanly resolve this? data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
            collate_fn=collate_fn,
            use_auto_aug=FLAGS.auto_augment,
            use_mixcut=FLAGS.mixcut,
        )

        eval_dir = os.path.join(FLAGS.data, 'val')
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
            raise SystemExit(1)
        dataset_eval = Dataset(eval_dir)

        test_loader = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=False,
            # Same prefetcher switch as the train loader above.
            use_prefetcher=not FLAGS.no_prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
        )

    writer = None
    start_epoch = 0
    if FLAGS.output and xm.is_master_ordinal():
        writer = SummaryWriter(log_dir=FLAGS.output)
    optimizer = create_optimizer(FLAGS, model)
    # NOTE(review): this scheduler is replaced below by the per-step torch_xla
    # scheduler, and `num_epochs` is unused (the loop reads FLAGS.epochs). It
    # is kept in case constructing it adjusts the optimizer's param groups.
    # Confirm whether it is still needed.
    lr_scheduler, num_epochs = create_scheduler(FLAGS, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())

    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=FLAGS.smoothing)

    # NOTE(review): assumes timm's ModelEma, which keeps the averaged network
    # in `.ema` and is not itself a module. Falls back to the object itself
    # when it has no `.ema`. Confirm against the ModelEma in use.
    ema_net = getattr(model_ema, 'ema', model_ema) if model_ema is not None else None

    def train_loop_fn(loader):
        """Run one training epoch, updating the EMA and LR scheduler per step."""
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = train_loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if model_ema is not None:
                model_ema.update(model)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                                 tracker.global_rate())

    def evaluate(net, loader):
        """Return top-1 accuracy (percent) of `net` over `loader`; 0.0 if empty."""
        total_samples = 0
        correct = 0
        net.eval()
        for x, (data, target) in loader:
            output = net(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples if total_samples else 0.0
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.epochs + 1):
        para_loader = dp.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))

        para_loader = dp.ParallelLoader(test_loader, [device])
        accuracy = evaluate(model, para_loader.per_device_loader(device))
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        if model_ema is not None:
            # The per-device loader above was consumed by the previous pass,
            # so the EMA evaluation needs its own ParallelLoader.
            para_loader = dp.ParallelLoader(test_loader, [device])
            accuracy = evaluate(ema_net, para_loader.per_device_loader(device))
            print('Epoch: {}, Mean EMA Accuracy: {:.2f}%'.format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)

        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
def train_mnist():
    """Train the MNIST model on the local XLA device and return test accuracy.

    Data comes from ``datasets.MNIST`` under ``FLAGS.datadir`` or, when
    ``FLAGS.fake_data`` is set, from all-zero synthetic batches. The SGD
    learning rate is ``FLAGS.lr`` multiplied by the replica count, and only the
    master ordinal opens a summary writer.

    Returns:
        Test accuracy in percent after the last epoch (0.0 if no epoch runs).
    """
    torch.manual_seed(1)
    replicas = xm.xrt_world_size()

    if FLAGS.fake_data:
        def synthetic_loader(num_images):
            # Zero-filled MNIST-shaped batches, split evenly across replicas.
            return xu.SampleGenerator(
                data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                      torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
                sample_count=num_images // FLAGS.batch_size // replicas)

        train_loader = synthetic_loader(60000)
        test_loader = synthetic_loader(10000)
    else:
        def mnist_transform():
            # A fresh pipeline per split: tensor conversion + MNIST mean/std.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, )),
            ])

        train_dataset = datasets.MNIST(FLAGS.datadir,
                                       train=True,
                                       download=True,
                                       transform=mnist_transform())
        test_dataset = datasets.MNIST(FLAGS.datadir,
                                      train=False,
                                      transform=mnist_transform())
        # Shard the training set only when more than one replica runs.
        train_sampler = (torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=replicas,
            rank=xm.get_ordinal(),
            shuffle=True) if replicas > 1 else None)
        loader_kwargs = dict(batch_size=FLAGS.batch_size,
                             drop_last=FLAGS.drop_last,
                             num_workers=FLAGS.num_workers)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   sampler=train_sampler,
                                                   shuffle=not train_sampler,
                                                   **loader_kwargs)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  shuffle=False,
                                                  **loader_kwargs)

    # Scale the base learning rate by the number of replicas.
    lr = FLAGS.lr * replicas

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = (test_utils.get_summary_writer(FLAGS.logdir)
              if xm.is_master_ordinal() else None)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        # One pass over `loader`, logging every FLAGS.log_steps steps.
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, step, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        # Top-1 accuracy (percent) of `model` over `loader`.
        model.eval()
        hits = seen = 0
        for data, target in loader:
            pred = model(data).max(1, keepdim=True)[1]
            hits += pred.eq(target.view_as(pred)).sum().item()
            seen += data.size()[0]
        accuracy = 100.0 * hits / seen
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        parallel = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(parallel.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        parallel = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(parallel.per_device_loader(device))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
Example #7
0
def train_cifar():
    """Train a ResNet-18 on CIFAR-10 through ``dp.DataParallel``.

    Uses ``torchvision.models.resnet18`` when ``FLAGS.use_torchvision`` is set,
    otherwise ``ResNet18``. Each epoch trains on every device, then evaluates;
    the mean of the per-device accuracies is printed and logged at step
    ``(epoch - 1) * num_training_steps_per_epoch``.

    Returns:
        Mean test accuracy in percent after the last epoch (0.0 if no epoch
        runs).
    """
    print('==> Preparing data..')
    world_size = xm.xrt_world_size()

    if FLAGS.fake_data:
        train_dataset_len = 50000  # Number of examples in the CIFAR train set.

        def synthetic_loader(num_examples):
            # All-zero CIFAR-shaped batches, split evenly across replicas.
            return xu.SampleGenerator(
                data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                      torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
                sample_count=num_examples // FLAGS.batch_size // world_size)

        train_loader = synthetic_loader(train_dataset_len)
        test_loader = synthetic_loader(10000)
    else:
        def cifar_normalize():
            # CIFAR-10 per-channel mean / std.
            return transforms.Normalize((0.4914, 0.4822, 0.4465),
                                        (0.2023, 0.1994, 0.2010))

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            cifar_normalize(),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            cifar_normalize(),
        ])

        train_dataset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                     train=True,
                                                     download=True,
                                                     transform=transform_train)
        train_dataset_len = len(train_dataset)
        test_dataset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                    train=False,
                                                    download=True,
                                                    transform=transform_test)

        def shard(dataset, shuffle):
            # One shard per replica; None (plain loading) on a single replica.
            if world_size <= 1:
                return None
            return torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=xm.get_ordinal(),
                shuffle=shuffle)

        train_sampler = shard(train_dataset, shuffle=True)
        test_sampler = shard(test_dataset, shuffle=False)
        loader_kwargs = dict(batch_size=FLAGS.batch_size,
                             drop_last=FLAGS.drop_last,
                             num_workers=FLAGS.num_workers)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   sampler=train_sampler,
                                                   shuffle=not train_sampler,
                                                   **loader_kwargs)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  sampler=test_sampler,
                                                  shuffle=False,
                                                  **loader_kwargs)

    torch.manual_seed(42)

    # An empty device list makes dp.DataParallel use the PyTorch/CPU engine.
    devices = ([] if FLAGS.num_cores == 0 else
               xm.get_xla_supported_devices(max_devices=FLAGS.num_cores))
    model_fn = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
    model_parallel = dp.DataParallel(model_fn, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # One training pass for one device. The optimizer is obtained through
        # context.getattr_or (presumably built once per device and reused).
        loss_fn = nn.CrossEntropyLoss()

        def make_optimizer():
            return optim.SGD(model.parameters(),
                             lr=FLAGS.lr,
                             momentum=FLAGS.momentum,
                             weight_decay=5e-4)

        optimizer = context.getattr_or('optimizer', make_optimizer)
        tracker = xm.RateTracker()

        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, step, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        # Top-1 accuracy (percent) of one device's replica over its loader.
        model.eval()
        hits = seen = 0
        for data, target in loader:
            pred = model(data).max(1, keepdim=True)[1]
            hits += pred.eq(target.view_as(pred)).sum().item()
            seen += data.size()[0]
        accuracy = 100.0 * hits / seen
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = test_utils.get_summary_writer(FLAGS.logdir)
    num_devices = (1 if len(devices) <= 1 else
                   len(xm.xla_replication_devices(devices)))
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracy = mean(model_parallel(test_loop_fn, test_loader))
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        test_utils.add_scalar_to_summary(
            writer, 'Accuracy/test', accuracy,
            (epoch - 1) * num_training_steps_per_epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
Example #8
0
def train_mnist():
    """Train the MNIST model through ``dp.DataParallel`` and return accuracy.

    The SGD learning rate is ``FLAGS.lr`` times the number of devices (at
    least 1). Each epoch trains on every device, then evaluates; the mean of
    the per-device accuracies is printed and logged at step
    ``(epoch - 1) * num_training_steps_per_epoch``.

    Returns:
        Mean test accuracy in percent after the last epoch (0.0 if no epoch
        runs).
    """
    torch.manual_seed(1)
    world_size = xm.xrt_world_size()

    if FLAGS.fake_data:
        train_dataset_len = 60000  # Number of images in the MNIST train set.

        def synthetic_loader(num_images):
            # All-zero MNIST-shaped batches, split evenly across replicas.
            return xu.SampleGenerator(
                data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                      torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
                sample_count=num_images // FLAGS.batch_size // world_size)

        train_loader = synthetic_loader(train_dataset_len)
        test_loader = synthetic_loader(10000)
    else:
        def mnist_transform():
            # A fresh pipeline per split: tensor conversion + MNIST mean/std.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, )),
            ])

        train_dataset = datasets.MNIST(FLAGS.datadir,
                                       train=True,
                                       download=True,
                                       transform=mnist_transform())
        train_dataset_len = len(train_dataset)
        test_dataset = datasets.MNIST(FLAGS.datadir,
                                      train=False,
                                      transform=mnist_transform())

        def shard(dataset, shuffle):
            # One shard per replica; None (plain loading) on a single replica.
            if world_size <= 1:
                return None
            return torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=xm.get_ordinal(),
                shuffle=shuffle)

        train_sampler = shard(train_dataset, shuffle=True)
        test_sampler = shard(test_dataset, shuffle=False)
        loader_kwargs = dict(batch_size=FLAGS.batch_size,
                             drop_last=FLAGS.drop_last,
                             num_workers=FLAGS.num_workers)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   sampler=train_sampler,
                                                   shuffle=not train_sampler,
                                                   **loader_kwargs)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  sampler=test_sampler,
                                                  shuffle=False,
                                                  **loader_kwargs)

    devices = ([] if FLAGS.num_cores == 0 else
               xm.get_xla_supported_devices(max_devices=FLAGS.num_cores))
    # Scale the learning rate by the number of devices (treat CPU as one).
    lr = FLAGS.lr * max(len(devices), 1)
    # An empty device list makes dp.DataParallel use the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # One training pass for one device. The optimizer is obtained through
        # context.getattr_or (presumably built once per device and reused).
        loss_fn = nn.NLLLoss()

        def make_optimizer():
            return optim.SGD(model.parameters(), lr=lr,
                             momentum=FLAGS.momentum)

        optimizer = context.getattr_or('optimizer', make_optimizer)
        tracker = xm.RateTracker()

        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, step, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        # Top-1 accuracy (percent) of one device's replica over its loader.
        model.eval()
        hits = seen = 0
        for data, target in loader:
            pred = model(data).max(1, keepdim=True)[1]
            hits += pred.eq(target.view_as(pred)).sum().item()
            seen += data.size()[0]
        accuracy = 100.0 * hits / seen
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = test_utils.get_summary_writer(FLAGS.logdir)
    num_devices = (1 if len(devices) <= 1 else
                   len(xm.xla_replication_devices(devices)))
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracy = mean(model_parallel(test_loop_fn, test_loader))
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        test_utils.add_scalar_to_summary(
            writer, 'Accuracy/test', accuracy,
            (epoch - 1) * num_training_steps_per_epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
def train_imagenet():
    """Train and evaluate an ImageNet model on this process's XLA device.

    Data comes from ``FLAGS.datadir`` (``train``/``val`` ImageFolder layout)
    or, with ``FLAGS.fake_data``, from all-zero synthetic batches. After each
    epoch the model is evaluated; whenever test top-1 accuracy improves, a
    checkpoint is written to ``./reports/{FLAGS.model}_model-{epoch}.pt``.

    With ``FLAGS.warm_start``, model/optimizer/scheduler state is restored
    from ``FLAGS.warm_start_checkpoint`` when that flag exists and is set,
    otherwise from ``./reports/resnet152_model-26.pt``. Training then resumes
    at the epoch after the one recorded in the checkpoint.

    Returns:
        Top-1 test accuracy (percent) of the last evaluated epoch, or 0.0 if
        no epoch ran.
    """
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size(),
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
            ),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
        )
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "val"),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]),
        )

        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True,
            )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers,
        )

    torch.manual_seed(42)

    device = xm.xla_device()
    # Move the model to the device BEFORE building the optimizer and before
    # restoring any checkpoint. Optimizer.load_state_dict casts saved state
    # (e.g. SGD momentum buffers) to the parameters' device at load time.
    # Restoring while the parameters were still on CPU therefore left those
    # buffers on CPU once the model was moved, breaking the first step.
    model = get_model_property("model_fn")().to(device)
    writer = None
    if FLAGS.logdir and xm.is_master_ordinal():
        writer = SummaryWriter(log_dir=FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
        scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, "lr_scheduler_divide_every_n_epochs", None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer,
    )
    # Epochs are numbered from 1 so that exactly FLAGS.num_epochs epochs run.
    start_epoch = 1
    if FLAGS.warm_start:
        # The checkpoint path is overridable; the default is the path this
        # script always loaded before.
        checkpoint_path = (getattr(FLAGS, "warm_start_checkpoint", None) or
                           "./reports/resnet152_model-26.pt")
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # "step" is None in checkpoints saved without an LR scheduler.
        if lr_scheduler is not None and checkpoint.get("step") is not None:
            lr_scheduler._step_count = checkpoint["step"]
        # A checkpoint is written after its epoch finished; resume at the next.
        start_epoch = checkpoint["epoch"] + 1
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        """Run one training epoch over ``loader``.

        Returns:
            ``(mean batch loss, top-1 accuracy in percent, mean batch top-5)``;
            all zeros if ``loader`` yields nothing.
        """
        tracker = xm.RateTracker()
        model.train()
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        losses = 0
        num_batches = 0
        # NOTE(review): the per-step ``.item()`` calls below likely force a
        # device-to-host sync every step on XLA; consider accumulating on
        # device and reading the totals once per epoch.
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            losses += loss.item()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5).item()
            num_batches += 1
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
        # max(..., 1) keeps an empty loader from dividing by zero.
        return (
            losses / max(num_batches, 1),
            100.0 * correct / max(total_samples, 1),
            top5_accuracys / max(num_batches, 1),
        )

    def test_loop_fn(loader):
        """Evaluate over ``loader``; return (top-1 %, mean batch top-5)."""
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        num_batches = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5).item()
            num_batches += 1

        accuracy = 100.0 * correct / max(total_samples, 1)
        test_utils.print_test_update(device, accuracy)
        return accuracy, top5_accuracys / max(num_batches, 1)

    accuracy = 0.0
    max_accuracy = 0.0
    start = time.time()
    for epoch in range(start_epoch, FLAGS.num_epochs + 1):
        epoch_start = time.time()
        para_loader = pl.ParallelLoader(train_loader, [device],
                                        loader_prefetch_size=32,
                                        device_prefetch_size=8)
        loss, accuracy, top5_accuracy = train_loop_fn(
            para_loader.per_device_loader(device))
        if xm.is_master_ordinal():
            print(
                "Finished training epoch {}, duration_time {} sec, total duration_time {} sec"
                .format(epoch,
                        time.time() - epoch_start,
                        time.time() - start))
            print(
                "Epoch: {} (Train), Loss {}, Top-1 Accuracy: {:.2f} Top-5 accuracy: {}"
                .format(epoch, loss, accuracy, top5_accuracy))
            test_utils.add_scalar_to_summary(writer, "Loss/train", loss, epoch)
            test_utils.add_scalar_to_summary(writer, "Top-1 Accuracy/train",
                                             accuracy, epoch)
            test_utils.add_scalar_to_summary(writer, "Top-5 Accuracy/train",
                                             top5_accuracy, epoch)
        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, top5_accuracy = test_loop_fn(
            para_loader.per_device_loader(device))
        if xm.is_master_ordinal():
            print(
                "Epoch: {} (Valid), Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".
                format(epoch, accuracy, top5_accuracy))
            test_utils.add_scalar_to_summary(writer, "Top-1 Accuracy/test",
                                             accuracy, epoch)
            test_utils.add_scalar_to_summary(writer, "Top-5 Accuracy/test",
                                             top5_accuracy, epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            xm.save(
                {
                    "epoch": epoch,
                    # lr_scheduler may be None (the training loop guards for
                    # that too), in which case there is no step to record.
                    "step": (lr_scheduler._step_count
                             if lr_scheduler is not None else None),
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                },
                f"./reports/{FLAGS.model}_model-{epoch}.pt",
                master_only=True,
            )
            if writer is not None:
                writer.flush()

    # The writer was opened above and is otherwise never closed.
    if writer is not None:
        writer.close()
    return accuracy
Example #10
0
def train_imagenet():
    """Train an ImageNet model through ``dp.DataParallel`` and return accuracy.

    Each epoch trains on every device, then evaluates. The mean of the
    per-device top-1 accuracies (in percent) is printed and logged as
    'Accuracy/test' against the epoch number.

    Returns:
        Mean top-1 test accuracy in percent after the last epoch (0.0 if no
        epoch runs).
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=1200000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    torchvision_model = get_model_property('model_fn')
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, x, loss.item(), tracker.rate()))

    def test_loop_fn(model, loader, device, context):
        # Returns top-1 accuracy in PERCENT. A fraction was returned before,
        # which made the epoch line print e.g. "0.76%" and logged fractions
        # while the per-device line already printed percent.
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        print('[{}] Accuracy={:.2f}%'.format(device, accuracy))
        return accuracy

    accuracy = 0.0
    writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    # Close the writer opened above so buffered events are written out.
    if writer is not None:
        writer.close()
    # Already a percentage (see test_loop_fn).
    return accuracy