Example #1
File: build.py  Project: naykun/mmf
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    # IterableDataset returns batches directly, so there is no need to add a
    # Sampler or batch size; the user is expected to control those. For now we
    # deliberately do not support single-item-based IterableDatasets, as that
    # would add unnecessary complexity and config parameters to the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
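At its core, the helper above follows a two-step pattern: build an ordinary torch.utils.data.DataLoader, then wrap it in MpDeviceLoader when running under XLA so batches are staged onto the device asynchronously. A minimal, self-contained sketch of just that pattern, assuming torch_xla is installed (the function name is illustrative):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl


def build_plain_device_loader(dataset: torch.utils.data.Dataset,
                              batch_size: int = 32):
    # Ordinary host-side DataLoader first...
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    # ...then the XLA wrapper, which prefetches each batch to the device
    # while the current step is still executing.
    device = xm.xla_device()
    return xla_pl.MpDeviceLoader(loader, device)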
Example #2
def main(args):
    torch.manual_seed(args.seed)
    device = xm.xla_device()
    loader_kwargs = {
        'num_workers': args.num_workers,
        'batch_size': args.batch_size,
    }

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    train_dataset = datasets.MNIST('data',
                                   train=True,
                                   download=True,
                                   transform=transform)
    val_dataset = datasets.MNIST('data', train=False, transform=transform)

    if xm.xrt_world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               sampler=train_sampler,
                                               **loader_kwargs)
    train_loader = pl.MpDeviceLoader(train_loader, device)

    test_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)
    test_loader = pl.MpDeviceLoader(test_loader, device)

    model = Net().to(device)

    # Scale learning rate to world size
    lr = args.learning_rate * xm.xrt_world_size()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.num_epochs + 1):
        train_one_epoch(args, model, device, train_loader, optimizer, epoch)
        validate(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
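A main like the one above runs once per XLA process; nothing in the snippet starts those processes. A hedged sketch of the usual launcher (parse_args is a hypothetical stand-in for the script's argument parsing):

import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index, args):
    # index is the local ordinal assigned by xmp.spawn
    main(args)


if __name__ == '__main__':
    args = parse_args()  # hypothetical: builds the namespace main() expects
    xmp.spawn(_mp_fn, args=(args,))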
Example #3
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }
    if version.parse(torch.__version__) >= version.parse("1.8"):
        # only use persistent workers in PyTorch 1.8 or higher
        # (PyTorch 1.7 also has this option but doesn't support it correctly due to
        # https://github.com/pytorch/pytorch/issues/48370)
        other_args["persistent_workers"] = (
            datamodule_config.get(
                "persistent_workers", training_config.get("persistent_workers", True)
            ),
        )
        if other_args["persistent_workers"] and other_args["num_workers"] == 0:
            logger.warning(
                "persistent_workers cannot be used together with num_workers == 0; "
                "setting persistent_workers to False"
            )
            other_args["persistent_workers"] = False

    # IterableDataset returns batches directly, so there is no need to add a
    # Sampler or batch size; the user is expected to control those. For now we
    # deliberately do not support single-item-based IterableDatasets, as that
    # would add unnecessary complexity and config parameters to the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    # Set drop_last=True when using XLA to have constant batch size.
    # In this case we also need to set drop_last=True in DistributedSampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
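The persistent_workers gate above is version-dependent. A small standalone sketch of the same guard, assuming the packaging library is available (the helper name is illustrative):

import torch
from packaging import version


def dataloader_kwargs(num_workers: int) -> dict:
    kwargs = {"num_workers": num_workers}
    if version.parse(torch.__version__) >= version.parse("1.8"):
        # persistent_workers keeps worker processes alive between epochs,
        # but it is only meaningful (and valid) when num_workers > 0.
        kwargs["persistent_workers"] = num_workers > 0
    return kwargs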
Example #4
def train_bert(dataset_path, xla_enabled, amp_enabled):
    max_seq_length = 128
    batch_size = 16
    num_epochs = 1
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # model = BERT()
    model = BERTdownsized()
    dat = pd.read_csv(dataset_path)
    print(dat.head())

    X = dat['review']
    y = dat['sentiment']

    X_train, X_test, y_train, y_test = train_test_split(X,
                                                        y,
                                                        test_size=0.10,
                                                        random_state=42)
    X_train = X_train.values.tolist()
    X_test = X_test.values.tolist()

    y_train = pd.get_dummies(y_train).values.tolist()
    y_test = pd.get_dummies(y_test).values.tolist()

    train_lists = [X_train, y_train]
    test_lists = [X_test, y_test]

    training_dataset = text_dataset(x_y_list=train_lists,
                                    max_seq_length=max_seq_length,
                                    tokenizer=tokenizer)

    test_dataset = text_dataset(x_y_list=test_lists,
                                max_seq_length=max_seq_length,
                                tokenizer=tokenizer)

    dataloaders_dict = {
        'train':
        torch.utils.data.DataLoader(training_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0),
        'val':
        torch.utils.data.DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0)
    }
    dataset_sizes = {'train': len(train_lists[0]), 'val': len(test_lists[0])}

    if xla_enabled:
        device = xm.xla_device()
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)
    lrlast = 1e-3
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lrlast)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    print('==> Starting Training')
    if amp_enabled:
        autocast, scaler = get_autocast_and_scaler(xla_enabled)

    if xla_enabled:
        import torch_xla.distributed.parallel_loader as pl
        server = xp.start_server(port_number)
        train_device_loader = pl.MpDeviceLoader(dataloaders_dict['train'],
                                                device)
        # train_device_loader = dataloaders_dict['train']
    else:
        train_device_loader = dataloaders_dict['train']

    if dlprof_enabled and not xla_enabled and False:  # 'and False' keeps this DLProf branch disabled
        with torch.autograd.profiler.emit_nvtx():
            for epoch in range(num_epochs):
                epoch_time = time.time()
                # tracker = xm.RateTracker()
                print('Epoch {}/{}'.format(epoch, num_epochs - 1))
                print('-' * 10)
                model.train()  # Set model to training mode
                # Iterate over data.
                for step, (inputs,
                           sentiment) in enumerate(train_device_loader):
                    tracker = xm.RateTracker()  # placing the tracker here frees it of I/O time
                    if not xla_enabled:  # This section is not necessary (but doesn't cause any performance problems) for XLA
                        inputs = inputs.to(device)
                        sentiment = sentiment.to(device)
                    optimizer.zero_grad()
                    if amp_enabled:
                        loss, optimizer = loop_with_amp(
                            model, inputs, sentiment, optimizer, xla_enabled,
                            autocast, scaler)
                    else:
                        loss, optimizer = loop_without_amp(
                            model, inputs, sentiment, optimizer, xla_enabled)
                    tracker.add(inputs.size(0))
                    _train_update(device, step, loss, tracker, epoch, None)

                time_elapsed = time.time() - epoch_time
                print(
                    f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s'
                )
    else:
        for epoch in range(num_epochs):
            epoch_time = time.time()
            # tracker = xm.RateTracker()
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
            model.train()  # Set model to training mode
            # Iterate over data.
            if cpu_mem_usage:
                import resource
                print(
                    f" CPU Usage Before: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}"
                )
            for step, (inputs, sentiment) in enumerate(train_device_loader):
                if step == 5:
                    training_started.set()
                tracker = xm.RateTracker()  # placing the tracker here frees it of I/O time
                if not xla_enabled:  # This section is not necessary (but doesn't cause any performance problems) for XLA
                    inputs = inputs.to(device)
                    sentiment = sentiment.to(device)
                optimizer.zero_grad()
                if amp_enabled:
                    loss, optimizer = loop_with_amp(model, inputs, sentiment,
                                                    optimizer, xla_enabled,
                                                    autocast, scaler)
                else:
                    loss, optimizer = loop_without_amp(model, inputs,
                                                       sentiment, optimizer,
                                                       xla_enabled)
                tracker.add(inputs.size(0))
                _train_update(device, step, loss, tracker, epoch, None)

            time_elapsed = time.time() - epoch_time
            print(
                f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s'
            )
    if xla_enabled and debug_enabled:
        import torch_xla.debug.metrics as met
        print(met.metrics_report())
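loop_with_amp and loop_without_amp are referenced above but not shown. A hedged sketch of what a loop_without_amp consistent with those call sites might look like; the loss computation here is an assumption, since the real helper is not in this snippet:

import torch
import torch_xla.core.xla_model as xm


def loop_without_amp(model, inputs, sentiment, optimizer, xla_enabled):
    output = model(inputs)
    # Assumption: one-hot sentiment labels are converted to class indices.
    loss = torch.nn.functional.cross_entropy(output, sentiment.argmax(dim=1))
    loss.backward()
    if xla_enabled:
        # xm.optimizer_step reduces gradients across replicas and marks
        # the step for the XLA graph executor.
        xm.optimizer_step(optimizer)
    else:
        optimizer.step()
    return loss, optimizer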
Example #5
File: build.py  Project: zhang703652632/mmf
def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType,
    training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring params for dataloader

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # IterableDataset returns batches directly, so there is no need to add a
    # Sampler or batch size; the user is expected to control those. For now we
    # deliberately do not support single-item-based IterableDatasets, as that
    # would add unnecessary complexity and config parameters to the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance,
                                                    other_args)

    if is_xla():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=other_args["shuffle"],
        )
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(dataset_instance.dataset_name,
                                 dataset_instance.dataset_type),
        num_workers=num_workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = pl.MpDeviceLoader(loader, device)

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
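This variant returns the sampler precisely so the caller can reshuffle it every epoch. A minimal sketch of that convention, with loader and sampler as returned by build_dataloader_and_sampler:

def run_epochs(loader, sampler, num_epochs):
    for epoch in range(num_epochs):
        if sampler is not None and hasattr(sampler, "set_epoch"):
            # Re-seed the DistributedSampler so each epoch sees a
            # different shard order across replicas.
            sampler.set_epoch(epoch)
        for batch in loader:
            ...  # forward/backward as usual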
Example #6
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            persistent_workers=True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            persistent_workers=True,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')()
    # Wrap the model with FSDP
    # You may wrap all, a subset, or none of the sub-modules with inner FSDPs
    # - to implement ZeRO-2, wrap none of the sub-modules
    # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP)
    # - you may wrap sub-modules at different granularity (e.g. at each resnet
    #   stage or each residual block or each conv layer).
    fsdp_wrap = lambda m: FSDP(
        m.to(device),
        compute_dtype=getattr(torch, FLAGS.compute_dtype),
        fp32_reduce_scatter=FLAGS.fp32_reduce_scatter,
        flatten_parameters=FLAGS.flatten_parameters)
    # Apply gradient checkpointing to sub-modules if specified
    grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else (
        lambda x: x)
    if FLAGS.use_nested_fsdp:
        # Here we apply inner FSDP at the level of child modules for ZeRO-3, which
        # corresponds to different stages in resnet (i.e. Stage 1 to 5).
        for submodule_name, submodule in model.named_children():
            if sum(p.numel() for p in submodule.parameters()) == 0:
                # Skip those submodules without parameters (i.e. no need to shard them)
                continue
            # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
            m_fsdp = fsdp_wrap(grad_ckpt_wrap(getattr(model, submodule_name)))
            setattr(model, submodule_name, m_fsdp)
    # Always wrap the base model with an outer FSDP
    model = fsdp_wrap(model)

    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.WarmupAndExponentialDecayScheduler(
        optimizer,
        num_steps_per_epoch=num_training_steps_per_epoch,
        divide_every_n_epochs=FLAGS.lr_scheduler_divide_every_n_epochs,
        divisor=FLAGS.lr_scheduler_divisor,
        num_warmup_epochs=FLAGS.num_warmup_epochs,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()  # do not reduce gradients on sharded params
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, epoch,
                                          writer))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(test_utils.print_test_update,
                                    args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))
        run_eval = ((not FLAGS.test_only_at_end
                     and epoch % FLAGS.eval_interval == 0)
                    or epoch == FLAGS.num_epochs)
        if run_eval:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={'Accuracy/test': accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
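The comment block above distinguishes ZeRO-2 (outer wrap only) from ZeRO-3 (nested wraps). A hedged sketch of both variants in isolation; the import path below matches torch_xla's FSDP module but may differ across releases:

import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP


def wrap_zero2(model):
    # ZeRO-2 style: a single outer wrap, no nested FSDP on sub-modules.
    return FSDP(model.to(xm.xla_device()))


def wrap_zero3(model):
    # ZeRO-3 style: wrap each parameterized child, then the root module.
    model = model.to(xm.xla_device())
    for name, child in model.named_children():
        if any(p.numel() > 0 for p in child.parameters()):
            setattr(model, name, FSDP(child))
    return FSDP(model)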
Example #7
def training(rank, world_size, backend, config):
    # Specific xla
    print(xm.get_ordinal(), ": run with config:", config, "- backend=", backend)
    device = xm.xla_device()

    # Data preparation
    dataset = RndDataset(nb_samples=config["nb_samples"])

    # Specific xla
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(),
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(config["batch_size"] / xm.xrt_world_size()),
        num_workers=1,
        sampler=train_sampler,
    )

    # Specific xla
    para_loader = pl.MpDeviceLoader(train_loader, device)

    # Model, criterion, optimizer setup
    model = wide_resnet50_2(num_classes=100).to(device)
    criterion = NLLLoss()
    optimizer = SGD(model.parameters(), lr=0.01)

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(batch_idx, data, target):

        # MpDeviceLoader already yields tensors on the XLA device, so no
        # explicit .to(device) calls are needed here.

        optimizer.zero_grad()
        output = model(data)
        # NLLLoss expects log-probabilities over the class dimension, so use
        # log_softmax along dim=1 (softmax along dim=0 would normalize across
        # the batch and give incorrect losses)
        log_probs = torch.nn.functional.log_softmax(output, dim=1)

        loss_val = criterion(log_probs, target)
        loss_val.backward()
        xm.optimizer_step(optimizer)

        if batch_idx % log_interval == 0:
            print(
                "Process {}/{} Train Epoch: {} [{}/{}]\tLoss: {}".format(
                    xm.get_ordinal(),
                    xm.xrt_world_size(),
                    epoch,
                    batch_idx * len(data),
                    len(train_sampler),
                    loss_val.item(),
                )
            )
        return loss_val

    # Running _train_step for n_epochs
    n_epochs = 1
    for epoch in range(n_epochs):
        for batch_idx, (data, target) in enumerate(para_loader):
            _train_step(batch_idx, data, target)
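MpDeviceLoader, used above, is a convenience wrapper. A hedged sketch of the older, equivalent ParallelLoader spelling that some torch_xla examples still use:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl


def make_device_loader(train_loader):
    device = xm.xla_device()
    # ParallelLoader can feed several devices at once; per_device_loader
    # selects the iterator for this process's device.
    para_loader = pl.ParallelLoader(train_loader, [device])
    return para_loader.per_device_loader(device)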
Example #8

def train_mnist(flags, training_started=None, dynamic_graph=False, fetch_often=False):
    torch.manual_seed(1)

    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=600000 // flags.batch_size // xm.xrt_world_size(),
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=100000 // flags.batch_size // xm.xrt_world_size(),
        )
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
            )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers,
        )

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    server = xp.start_server(flags.profiler_port)

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                index = max(-step, -flags.batch_size + 1)  # non-empty
                data, target = data[:-index, :, :, :], target[:-index]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()

            with xp.StepTrace("train_mnist", step_num=step):
                with xp.Trace("build_graph"):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss_i = loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(_train_update, args=(device, step, loss, tracker, writer))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace("test_mnist"):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print("Epoch {} train end {}".format(epoch, test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print(
            "Epoch {} test end {}, Accuracy={:.2f}".format(epoch, test_utils.now(), accuracy)
        )
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer, epoch, dict_to_write={"Accuracy/test": accuracy}, write_xla_metrics=True
        )
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy
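train_loop_fn above brackets each step in xp.StepTrace/xp.Trace spans so the profiler server started with xp.start_server can attribute time to named phases. The same pattern in isolation (the port number here is an arbitrary choice for the sketch):

import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # arbitrary port for this sketch


def profiled_step(step, compute_fn):
    # StepTrace marks one training step; nested Trace spans label phases
    # within it for the captured profile.
    with xp.StepTrace("train_step", step_num=step):
        with xp.Trace("compute"):
            compute_fn()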
Example #9
def train_imagenet():
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size(),
        )
        if FLAGS.validate:
            test_loader = xu.SampleGenerator(
                data=(
                    torch.zeros(FLAGS.test_set_batch_size, 3, img_dim,
                                img_dim),
                    torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
                ),
                sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
            )
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        if FLAGS.validate:
            test_dataset = torchvision.datasets.ImageFolder(
                os.path.join(FLAGS.datadir, "val"),
                # Matches Torchvision's eval transforms except Torchvision uses size
                # 256 resize for all models both here and in the train loader. Their
                # version crashes during training on 299x299 images, e.g. inception.
                transforms.Compose([
                    transforms.Resize(resize_dim),
                    transforms.CenterCrop(img_dim),
                    transforms.ToTensor(),
                    normalize,
                ]),
            )

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            if FLAGS.validate:
                test_sampler = torch.utils.data.distributed.DistributedSampler(
                    test_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        if FLAGS.validate:
            test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=FLAGS.test_set_batch_size,
                sampler=test_sampler,
                drop_last=FLAGS.drop_last,
                shuffle=False,
                num_workers=FLAGS.num_workers,
            )

    device = xm.xla_device()
    model = get_model_property("model_fn")().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
        scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, "lr_scheduler_divide_every_n_epochs", None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer,
    )
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()

    def train_loop_fn(loader, epoch):
        if FLAGS.fine_grained_metrics:
            epoch_start_time = time.time()
            step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
        else:
            tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if FLAGS.fine_grained_metrics:
                step_start_time = time.time()
            optimizer.zero_grad()
            if FLAGS.fine_grained_metrics:
                fwd_start_time = time.time()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)
            if FLAGS.fine_grained_metrics:
                fwd_end_time = time.time()
                fwd_latency = fwd_end_time - fwd_start_time

                bwd_start_time = time.time()
            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            if lr_scheduler:
                lr_scheduler.step()
            if FLAGS.fine_grained_metrics:
                bwd_end_time = time.time()
                bwd_latency = bwd_end_time - bwd_start_time

                step_latency = bwd_end_time - step_start_time
                step_latency_tracker.append(step_latency)
                bwd_latency_tracker.append(bwd_latency)
                fwd_latency_tracker.append(fwd_latency)
            else:
                tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                if FLAGS.fine_grained_metrics:
                    print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} '
                          'BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} '
                          'Bwd(s/Batch)[p50]={:.4f}'.format(
                              epoch, step, FLAGS.batch_size / p50(step_latency_tracker),
                              FLAGS.batch_size, p50(step_latency_tracker),
                              p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
                else:
                    # _train_update(device, step, loss, tracker, epoch, writer)
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              epoch, writer))
        if FLAGS.fine_grained_metrics:
            epoch_end_time = time.time()
            epoch_latency = epoch_end_time - epoch_start_time
            print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} '
                  'BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} '
                  'Bwd(s/Batch)[p50]={:.4f}'.format(
                      epoch, epoch_latency, FLAGS.batch_size / p50(step_latency_tracker),
                      FLAGS.batch_size, p50(step_latency_tracker),
                      p50(bwd_latency_tracker), p50(fwd_latency_tracker)))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                test_utils.print_test_update(device, None, epoch, step)
                # xm.add_step_closure(test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    if FLAGS.validate:
        test_device_loader = pl.MpDeviceLoader(test_loader, device)
        accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print("Epoch {} train end {}".format(epoch,
                                                       test_utils.now()))
        if FLAGS.validate:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={"Accuracy/test": accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    if FLAGS.validate:
        xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy if FLAGS.validate else None
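The fine-grained metrics branch calls a p50 helper that is not shown. A plausible definition, assuming it is simply the median of the recorded per-step latencies:

import numpy as np


def p50(latencies):
    # Median (50th percentile) of the recorded latencies, in seconds.
    return float(np.percentile(latencies, 50))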
Example #10
    def fit(self, train_loader, validation_loader):

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ] 

        # Try use different LR for HEAD and EffNet
        # self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.GPU_LR)
        LR = self.config.TPU_LR
        if global_config.CONTINUE_TRAIN: # Continue training proc -> Hand-tune LR 
            LR = self.config.TPU_LR # [9e-4, 1e-3]
        self.optimizer = torch.optim.AdamW([
                    {'params': self.model.efn.parameters(),       'lr': LR[0]},
                    {'params': self.model.fc1.parameters(),       'lr': LR[1]},
                    {'params': self.model.bn1.parameters(),       'lr': LR[1]},
                    {'params': self.model.dense_out.parameters(), 'lr': LR[1]}
                    ])

        ############################################## 
        self.scheduler = self.config.SchedulerClass(self.optimizer, **self.config.scheduler_params)

        # num_train_steps = int(self.steps * (global_config.GPU_EPOCH))
        # self.scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        #     self.optimizer,
        #     num_warmup_steps=int(num_train_steps * 0.05), # WARMUP_PROPORTION = 0.1 as default
        #     num_training_steps=num_train_steps,
        #     num_cycles=0.5
        # )
        ##############################################
        # DataLoader should init only once (outside the epoch loop) 
        train_device_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())
        # A sentinel value of 1 means validation is skipped
        if validation_loader != 1:
            val_device_loader = pl.MpDeviceLoader(validation_loader, xm.xla_device())
        ############################################## 

        for e in range(self.config.TPU_EPOCH):

            ############## Training
            gc.collect()
            t = time.time()
            xm.master_print("---" * 31)
            summary_loss, final_scores = self.train_one_epoch(train_device_loader)

            effNet_lr = np.format_float_scientific(self.optimizer.param_groups[0]['lr'], unique=False, precision=1)
            head_lr   = np.format_float_scientific(self.optimizer.param_groups[1]['lr'], unique=False, precision=1) 
            self.log(f":::[Train RESULT]| Epoch: {str(self.epoch).rjust(2, ' ')} | Loss: {summary_loss.avg:.4f} | AUC: {final_scores.avg:.4f} | LR: {effNet_lr}/{head_lr} | Time: {int((time.time() - t)//60)}m")
            self.save(f'{self.base_dir}/last_ckpt.pt')

            ############## Validation
            gc.collect()
            t = time.time()
            # Skip validation when the sentinel value 1 is passed
            if validation_loader != 1:
                summary_loss, final_scores = self.validation(val_device_loader)

            self.log(f":::[Valid RESULT]| Epoch: {str(self.epoch).rjust(2, ' ')} | Loss: {summary_loss.avg:.4f} | AUC: {final_scores.avg:.4f} | LR: {effNet_lr}/{head_lr} | Time: {int((time.time() - t)//60)}m")

            if summary_loss.avg < self.best_summary_loss:
                self.best_summary_loss = summary_loss.avg
                self.model.eval()
                self.save(f'{self.base_dir}/{global_config.SAVED_NAME}_{str(self.epoch).zfill(3)}ep.pt')
                # keep only the best 3 checkpoints
                # for path in sorted(glob(f'{self.base_dir}/{global_config.SAVED_NAME}_*ep.pt'))[:-3]:
                #     os.remove(path)

            if self.config.validation_scheduler:
                try:
                    self.scheduler.step(metrics=summary_loss.avg)
                except TypeError:
                    # scheduler.step() does not accept a `metrics` argument
                    self.scheduler.step()
                    
            self.epoch += 1
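The fit loop above distinguishes metric-driven schedulers (step(metrics=...)) from plain ones by catching the failure. An alternative sketch that makes the dispatch explicit, assuming ReduceLROnPlateau is the only metric-driven scheduler in play:

from torch.optim.lr_scheduler import ReduceLROnPlateau


def step_scheduler(scheduler, metric):
    # ReduceLROnPlateau consumes a metric; most other schedulers do not.
    if isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(metric)
    else:
        scheduler.step()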
Example #11
def train_bert(dataset_path, xla_enabled, amp_enabled):
    max_seq_length = 256
    batch_size = 32
    num_epochs = 25
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    model = BERT()
    dat = pd.read_csv(dataset_path)
    print(dat.head())

    X = dat['review']
    y = dat['sentiment']

    X_train, X_test, y_train, y_test = train_test_split(X,
                                                        y,
                                                        test_size=0.10,
                                                        random_state=42)
    X_train = X_train.values.tolist()
    X_test = X_test.values.tolist()

    y_train = pd.get_dummies(y_train).values.tolist()
    y_test = pd.get_dummies(y_test).values.tolist()

    train_lists = [X_train, y_train]
    test_lists = [X_test, y_test]

    training_dataset = text_dataset(x_y_list=train_lists,
                                    max_seq_length=max_seq_length,
                                    tokenizer=tokenizer)

    test_dataset = text_dataset(x_y_list=test_lists,
                                max_seq_length=max_seq_length,
                                tokenizer=tokenizer)

    dataloaders_dict = {
        'train':
        torch.utils.data.DataLoader(training_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0),
        'val':
        torch.utils.data.DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0)
    }
    dataset_sizes = {'train': len(train_lists[0]), 'val': len(test_lists[0])}

    if xla_enabled:
        device = xm.xla_device()
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)
    lrlast = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=lrlast)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    model = model.to(device)
    print('==> Starting Training')
    if amp_enabled:
        autocast, scaler = get_autocast_and_scaler(xla_enabled)

    train_device_loader = pl.MpDeviceLoader(dataloaders_dict['train'], device)

    for epoch in range(num_epochs):
        epoch_time = time.time()
        tracker = xm.RateTracker()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        model.train()  # Set model to training mode
        # Iterate over data.
        for step, (inputs, sentiment) in enumerate(train_device_loader):
            # import pdb;pdb.set_trace()
            # sentiment = torch.max(sentiment.float(), 1)[1]
            # inputs = inputs.to(device)
            # sentiment = sentiment.to(device)
            optimizer.zero_grad()
            if amp_enabled:
                loss, optimizer = loop_with_amp(model, inputs, sentiment,
                                                optimizer, xla_enabled,
                                                autocast, scaler)
            else:
                loss, optimizer = loop_without_amp(model, inputs, sentiment,
                                                   optimizer, xla_enabled)
            tracker.add(inputs.size(0))
            _train_update(device, step, loss, tracker, epoch, None)

        time_elapsed = time.time() - epoch_time
        print(f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s')
Example #12
def train_eval_loop(
    model,
    loss,
    optimizer,
    scheduler,
    train_loader,
    test_loader,
    device,
    epochs,
    verbose,
    save,
    save_freq=None,
    save_path=None,
    epoch_offset=0,
    **kwargs,
):
    print_fn = print
    if device.type == "xla":
        import torch_xla.distributed.parallel_loader as pl
        import torch_xla.core.xla_model as xm

        print_fn = xm.master_print
        train_loader = pl.MpDeviceLoader(train_loader, device)
        test_loader = pl.MpDeviceLoader(test_loader, device)

    test_loss, accuracy1, accuracy5 = eval(model, loss, test_loader, device, verbose, 0)
    metric_dict = {
        "train_loss": 0,
        "test_loss": test_loss,
        "accuracy1": accuracy1,
        "accuracy5": accuracy5,
    }
    if save:
        checkpoint(
            model,
            optimizer,
            scheduler,
            0,
            0,
            save_path,
            verbose,
            metric_dict,
            tpu=(device.type == "xla"),
        )
    for epoch in tqdm(range(epoch_offset, epoch_offset + epochs)):
        train_loss = train(
            model,
            loss,
            optimizer,
            scheduler,
            train_loader,
            device,
            epoch,
            verbose,
            save,
            save_freq=save_freq,
            save_path=save_path,
            **kwargs,
        )
        test_loss, accuracy1, accuracy5 = eval(
            model, loss, test_loader, device, verbose, epoch + 1
        )
        metric_dict = {
            "train_loss": train_loss,
            "test_loss": test_loss,
            "accuracy1": accuracy1,
            "accuracy5": accuracy5,
        }
        curr_step = (epoch + 1) * kwargs.get("num_batches")
        if save:
            checkpoint(
                model,
                optimizer,
                scheduler,
                epoch,
                curr_step,
                save_path,
                verbose,
                metric_dict,
                tpu=(device.type == "xla"),
            )
        scheduler.step()
    if epochs > 0:
        print_fn(
            f"Final performance: "
            f"\tTrain Loss: {train_loss:.4f}"
            f"\tTest Loss: {test_loss:.4f}"
            f"\tAccuracy: {accuracy1:.2f}%"
        )
Example #13
def _main_xla(index, args):
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

    alphabet = alphabet_factory()
    train_dataset, test_dataset = split_dataset(args, alphabet)
    collate_fn = collate_factory(model_length_function)
    if xm.xrt_world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               num_workers=args.num_workers,
                                               collate_fn=collate_fn,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              collate_fn=collate_fn,
                                              drop_last=True)

    # Scale learning rate to world size
    lr = args.learning_rate * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = build_deepspeech(in_features=in_features,
                             num_classes=len(alphabet))
    model = model.to(device)
    optimizer = get_optimizer(args, model.parameters())
    criterion = nn.CTCLoss(blank=alphabet.mapping[alphabet.char_blank])
    decoder = GreedyDecoder()

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)

    class XLAProxyOptimizer:
        """
        XLA Proxy optimizer for compatibility with
        torch.Optimizer
        """
        def __init__(self, optimizer):
            self.optimizer = optimizer

        def zero_grad(self):
            self.optimizer.zero_grad()

        def step(self):
            xm.optimizer_step(self.optimizer)

    optimizer = XLAProxyOptimizer(optimizer)

    train_eval_fn(args.num_epochs, train_device_loader, test_device_loader,
                  optimizer, model, criterion, device, decoder, alphabet,
                  args.checkpoint)
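XLAProxyOptimizer above exists so device-agnostic training code can call optimizer.step() uniformly. A hedged sketch of the same idea expressed as a factory function instead of a class:

import torch_xla.core.xla_model as xm


def make_step_fn(optimizer, xla_enabled):
    if xla_enabled:
        # On XLA, optimizer_step also performs cross-replica gradient
        # reduction and triggers graph execution.
        return lambda: xm.optimizer_step(optimizer)
    return optimizer.step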
Example #14
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
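            # All-reduce gradients across replicas before the optimizer update.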
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, epoch,
                                          writer))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(test_utils.print_test_update,
                                    args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

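    # MpDeviceLoader wraps the host DataLoader and prefetches batches onto the
    # XLA device in background threads.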
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))
        accuracy = test_loop_fn(test_device_loader, epoch)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
Example #15
def run_fold(fold):
    create_dirs()
    print_fn = print if not config.USE_TPU else xm.master_print
    print_fn(f"___________________________________________________")
    print_fn(f"Training Model:              {config.NET}")
    print_fn(f"Training Fold:               {fold}")
    print_fn(f"Image Dimensions:            {config.H}x{config.W}")
    print_fn(f"Mixed Precision Training:    {config.MIXED_PRECISION_TRAIN}")
    print_fn(f"Training Batch Size:         {config.TRAIN_BATCH_SIZE}")
    print_fn(f"Validation Batch Size:       {config.VALID_BATCH_SIZE}")
    print_fn(f"Accumulate Iteration:        {config.ACCUMULATE_ITERATION}")

    global net
    train_loader, valid_loader = get_loaders(fold)
    device = get_device(n=fold + 1)
    net = net.to(device)
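    # torch.cuda.amp.GradScaler is CUDA-only, so TPU runs leave the scaler unset.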
    scaler = (torch.cuda.amp.GradScaler()
              if not config.USE_TPU and config.MIXED_PRECISION_TRAIN else None)
    loss_tr = get_train_criterion(device=device)
    loss_fn = get_valid_criterion(device=device)
    optimizer, scheduler = get_optimizer_and_scheduler(net=net,
                                                       dataloader=train_loader)

    gc.collect()

    for epoch in range(config.MAX_EPOCHS):
        epoch_start = time.time()

        if config.DO_FREEZE_BATCH_NORM and epoch < config.FREEZE_BN_EPOCHS:
            freeze_batchnorm_stats(net)

        if config.USE_TPU:
            train_mp_device_loader = pl.MpDeviceLoader(
                train_loader, device, fixed_batch_size=True)
        else:
            train_mp_device_loader = train_loader
        train_one_epoch(fold,
                        epoch,
                        net,
                        loss_tr,
                        optimizer,
                        train_mp_device_loader,
                        device,
                        scaler=scaler,
                        scheduler=scheduler,
                        schd_batch_update=config.SCHEDULER_BATCH_STEP)
        del train_mp_device_loader
        gc.collect()

        if config.USE_TPU:
            valid_mp_device_loader = pl.MpDeviceLoader(
                valid_loader, device, fixed_batch_size=True)
        else:
            valid_mp_device_loader = valid_loader
        valid_one_epoch(fold,
                        epoch,
                        net,
                        loss_fn,
                        valid_mp_device_loader,
                        device,
                        scheduler=None,
                        schd_loss_update=False)
        del valid_mp_device_loader
        gc.collect()
        print_fn(
            f'[{fold}/{config.FOLDS - 1}][{epoch:>2d}/{config.MAX_EPOCHS - 1:>2d}] '
            f'Time Taken for Epoch {epoch}: {time.time() - epoch_start:.1f} seconds |'
        )

        save_fn = xm.save if config.USE_TPU else torch.save
        save_fn(
            net.state_dict(),
            os.path.join(
                config.WEIGHTS_PATH,
                f'{config.NET}/{config.NET}_fold_{fold}_{epoch}.bin'))

    del net, optimizer, train_loader, valid_loader, scheduler
    torch.cuda.empty_cache()
    print_fn(f"___________________________________________________")
Example #16
def train_mnist(flags, state_dict):
    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=60000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=10000 // flags.batch_size // xm.xrt_world_size())
    else:
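        # Each replica downloads into its own per-ordinal directory to avoid
        # concurrent writes to the same files.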
        train_dataset = datasets.MNIST(os.path.join(flags.datadir,
                                                    str(xm.get_ordinal())),
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        test_dataset = datasets.MNIST(os.path.join(flags.datadir,
                                                   str(xm.get_ordinal())),
                                      train=False,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=train_sampler is None,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST()
    model.load_state_dict(state_dict)
    model = model.to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, writer),
                                    run_async=flags.async_closures)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        # accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
Example #17
def train_mnist(flags, **kwargs):
  torch.manual_seed(1)

  if flags.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28,
                          28), torch.zeros(flags.batch_size,
                                           dtype=torch.int64)),
        sample_count=60000 // flags.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28,
                          28), torch.zeros(flags.batch_size,
                                           dtype=torch.int64)),
        sample_count=10000 // flags.batch_size // xm.xrt_world_size())
  else:
    train_dataset = datasets.MNIST(
        os.path.join(flags.datadir, str(xm.get_ordinal())),
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    test_dataset = datasets.MNIST(
        os.path.join(flags.datadir, str(xm.get_ordinal())),
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    train_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags.batch_size,
        sampler=train_sampler,
        drop_last=flags.drop_last,
        shuffle=train_sampler is None,
        num_workers=flags.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=flags.batch_size,
        drop_last=flags.drop_last,
        shuffle=False,
        num_workers=flags.num_workers)

  # Scale learning rate to num cores
  lr = flags.lr * xm.xrt_world_size()

  device = xm.xla_device()
  model = MNIST()
  # Wrap the model with FSDP
  fsdp_wrap = lambda m: FSDP(
      m.to(device),
      compute_dtype=getattr(torch, flags.compute_dtype),
      fp32_reduce_scatter=flags.fp32_reduce_scatter,
      flatten_parameters=flags.flatten_parameters)
  # Apply gradient checkpointing to sub-modules if specified
  grad_ckpt_wrap = checkpoint_module if flags.use_gradient_checkpointing else (
      lambda x: x)
  if flags.use_nested_fsdp:
    # Wrap a few sub-modules with inner FSDP (to implement ZeRO-3)
    # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
    model.conv1 = fsdp_wrap(grad_ckpt_wrap(model.conv1))
    model.conv2 = fsdp_wrap(grad_ckpt_wrap(model.conv2))
    model.fc1 = fsdp_wrap(grad_ckpt_wrap(model.fc1))
    model.fc2 = fsdp_wrap(grad_ckpt_wrap(model.fc2))
  # Always wrap the base model with an outer FSDP
  model = fsdp_wrap(model)

  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(flags.logdir)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(model, loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
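      # FSDP reduce-scatters gradients inside backward, so a plain
      # optimizer.step() is used here instead of xm.optimizer_step().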
      optimizer.step()  # do not reduce gradients on sharded params
      tracker.add(flags.batch_size)
      if step % flags.log_steps == 0:
        xm.add_step_closure(
            _train_update,
            args=(device, step, loss, tracker, writer),
            run_async=flags.async_closures)

  def test_loop_fn(model, loader):
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct.item() / total_samples
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, flags.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(model, train_device_loader)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

    accuracy = test_loop_fn(model, test_device_loader)
    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
        epoch, test_utils.now(), accuracy))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        epoch,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True)
    if flags.metrics_debug:
      xm.master_print(met.metrics_report())

  if flags.ckpt_consolidation:
    # Note: to run this test, all the model checkpoints need to be
    # accessible from the master rank. Set --ckpt_prefix to a shared file
    # system (e.g. NFS) when running on a TPU pod.

    # Save the final model checkpoint
    rank = xm.get_ordinal()
    world_size = xm.xrt_world_size()
    ckpt_path = f'{flags.ckpt_prefix}_rank-{rank:08d}-of-{world_size:08d}.pth'
    ckpt = {
        'model': model.state_dict(),
        'shard_metadata': model.get_shard_metadata(),
        'optimizer': optimizer.state_dict(),  # not needed in ckpt consolidation
    }
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    xm.save(ckpt, ckpt_path, master_only=False)
    print(f'checkpoint saved to {ckpt_path}\n', end='')

    # Consolidate the sharded model checkpoints and test its accuracy
    if xm.is_master_ordinal(local=False):
      consolidate_sharded_model_checkpoints(
          ckpt_prefix=flags.ckpt_prefix, ckpt_suffix="_rank-*-of-*.pth")
    xm.rendezvous('ckpt_consolidation')
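    # Load the consolidated (unsharded) checkpoint into a fresh, unwrapped model.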
    model = MNIST().to(device)
    ckpt_consolidated = torch.load(f'{flags.ckpt_prefix}_consolidated.pth')
    model.load_state_dict(ckpt_consolidated['model'])
    accuracy = test_loop_fn(model, test_device_loader)
    xm.master_print(
        f'Checkpoint consolidated, Accuracy={accuracy:.2f} '
        '(note: it can be slightly different from the final training accuracy '
        'due to non-sync BatchNorm2d in the model)')

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy