Exemplo n.º 1
0
    def start(self, action_name: str) -> None:
        """Open a profiler trace scope for ``action_name`` if it is recorded.

        Actions not listed in ``self.RECORD_FUNCTIONS`` are ignored. The
        first recorded action also starts the profiler server on
        ``self.port``. Actions listed in ``self.STEP_FUNCTIONS`` get a
        ``StepTrace`` tagged with their step number; all others get a plain
        ``Trace``.

        Args:
            action_name: Name of the action whose trace scope should open.
        """
        if action_name not in self.RECORD_FUNCTIONS:
            return

        # Bring the profiler server up lazily, only once.
        if not self._start_trace:
            self.server = xp.start_server(self.port)
            self._start_trace = True

        if action_name in self.STEP_FUNCTIONS:
            step_index = self._get_step_num(action_name)
            trace_scope = xp.StepTrace(action_name, step_num=step_index)
        else:
            trace_scope = xp.Trace(action_name)

        # Entered by hand rather than via ``with``: the scope is stored so it
        # can outlive this call. NOTE(review): presumably a matching stop
        # method calls ``__exit__`` on it -- not visible here, confirm.
        trace_scope.__enter__()
        self._recording_map[action_name] = trace_scope
Exemplo n.º 2
0
def train_bert(dataset_path, xla_enabled, amp_enabled):
    """Train the ``BERTdownsized`` model on a review/sentiment CSV.

    Reads ``dataset_path`` with pandas, holds out 10% of the rows, and runs
    one epoch of training on an XLA device, on CUDA, or on the CPU.

    Args:
        dataset_path: Path of a CSV file with ``review`` and ``sentiment``
            columns.
        xla_enabled: When true, train on ``xm.xla_device()``, start the
            profiler server and feed batches through ``MpDeviceLoader``.
        amp_enabled: When true, run each step through ``loop_with_amp`` with
            the autocast/scaler pair from ``get_autocast_and_scaler``;
            otherwise use ``loop_without_amp``.

    Note:
        Besides its parameters the function reads these module-level names:
        ``port_number``, ``dlprof_enabled``, ``cpu_mem_usage``,
        ``training_started`` and ``debug_enabled``.
    """
    import contextlib

    max_seq_length = 128
    batch_size = 16
    num_epochs = 1
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    model = BERTdownsized()
    dat = pd.read_csv(dataset_path)
    # head must be *called*; the bare attribute is the bound method, not a
    # preview of the first rows.
    print(dat.head())

    X = dat['review']
    y = dat['sentiment']

    X_train, X_test, y_train, y_test = train_test_split(X,
                                                        y,
                                                        test_size=0.10,
                                                        random_state=42)
    X_train = X_train.values.tolist()
    X_test = X_test.values.tolist()

    # One-hot encode the labels.
    y_train = pd.get_dummies(y_train).values.tolist()
    y_test = pd.get_dummies(y_test).values.tolist()

    train_lists = [X_train, y_train]
    test_lists = [X_test, y_test]

    training_dataset = text_dataset(x_y_list=train_lists,
                                    max_seq_length=max_seq_length,
                                    tokenizer=tokenizer)

    test_dataset = text_dataset(x_y_list=test_lists,
                                max_seq_length=max_seq_length,
                                tokenizer=tokenizer)

    # Only the 'train' loader is consumed below; 'val' is still built so the
    # construction side of the original is unchanged.
    dataloaders_dict = {
        'train':
        torch.utils.data.DataLoader(training_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0),
        'val':
        torch.utils.data.DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=0)
    }

    if xla_enabled:
        device = xm.xla_device()
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)
    lrlast = 1e-3
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lrlast)
    print('==> Starting Training')

    # Only consulted when amp_enabled is true; given defaults so the names
    # are always bound.
    autocast, scaler = None, None
    if amp_enabled:
        autocast, scaler = get_autocast_and_scaler(xla_enabled)

    if xla_enabled:
        import torch_xla.distributed.parallel_loader as pl
        # The return value is deliberately kept in a local.
        # NOTE(review): presumably the server must stay referenced for the
        # duration of training -- confirm against the profiler docs.
        server = xp.start_server(port_number)
        train_device_loader = pl.MpDeviceLoader(dataloaders_dict['train'],
                                                device)
    else:
        train_device_loader = dataloaders_dict['train']

    # NVTX emission is hard-disabled by the trailing ``and False``. The
    # condition is kept (and still evaluated) so it can be re-enabled in one
    # place. The previously duplicated, unreachable copy of the training loop
    # is folded into the single loop below.
    use_nvtx = dlprof_enabled and not xla_enabled and False

    with contextlib.ExitStack() as stack:
        if use_nvtx:
            stack.enter_context(torch.autograd.profiler.emit_nvtx())

        for epoch in range(num_epochs):
            epoch_time = time.time()
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
            model.train()  # Set model to training mode
            if cpu_mem_usage:
                import resource
                print(
                    f" CPU Usage Before: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}"
                )
            # Iterate over data.
            for step, (inputs, sentiment) in enumerate(train_device_loader):
                if step == 5:
                    # Signal observers once a few steps have been run.
                    training_started.set()
                # Placing the tracker here frees it of I/O time.
                tracker = xm.RateTracker()
                if not xla_enabled:
                    # Explicit transfer is only needed off XLA; the original
                    # author notes it is unnecessary (but harmless) there.
                    inputs = inputs.to(device)
                    sentiment = sentiment.to(device)
                optimizer.zero_grad()
                if amp_enabled:
                    loss, optimizer = loop_with_amp(model, inputs, sentiment,
                                                    optimizer, xla_enabled,
                                                    autocast, scaler)
                else:
                    loss, optimizer = loop_without_amp(model, inputs,
                                                       sentiment, optimizer,
                                                       xla_enabled)
                tracker.add(inputs.size(0))
                _train_update(device, step, loss, tracker, epoch, None)

            time_elapsed = time.time() - epoch_time
            print(
                f'Epoch complete in {time_elapsed // 60}m {time_elapsed % 60}s'
            )

    if xla_enabled and debug_enabled:
        import torch_xla.debug.metrics as met
        print(met.metrics_report())
Exemplo n.º 3
0
def train_mnist(flags, **kwargs):
    """Train and evaluate the MNIST model on an XLA device with tracing on.

    Builds either synthetic all-zero loaders (``flags.fake_data``) or real
    MNIST loaders, starts the client-side profiler server, then alternates a
    traced training pass and a traced test pass for ``flags.num_epochs``
    epochs.

    Args:
        flags: Configuration object; the attributes read here are
            ``fake_data``, ``batch_size``, ``datadir``, ``drop_last``,
            ``num_workers``, ``lr``, ``momentum``, ``logdir``,
            ``profiler_port``, ``log_steps``, ``num_epochs`` and
            ``metrics_debug``.
        **kwargs: Optional ``worker_started`` event; when present and truthy
            it is removed from ``kwargs`` and set once the profiler server
            has been started (testing hook).

    Returns:
        The highest per-epoch test accuracy, in percent, after averaging
        across replicas with ``xm.mesh_reduce``.
    """
    torch.manual_seed(1)

    def _fake_loader(total_samples):
        # Synthetic loader of zero-filled image/label batches.
        return xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=total_samples // flags.batch_size //
            xm.xrt_world_size())

    def _mnist_split(is_train):
        # Each ordinal downloads into its own sub-directory of datadir.
        return datasets.MNIST(os.path.join(flags.datadir,
                                           str(xm.get_ordinal())),
                              train=is_train,
                              download=True,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.1307, ),
                                                       (0.3081, ))
                              ]))

    if flags.fake_data:
        train_loader = _fake_loader(60000)
        test_loader = _fake_loader(10000)
    else:
        train_dataset = _mnist_split(True)
        test_dataset = _mnist_split(False)
        if xm.xrt_world_size() > 1:
            # Shard the training set across replicas.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        else:
            train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            # The loader shuffles itself only when no sampler is in use.
            shuffle=not train_sampler,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    scaled_lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    # Only the master ordinal gets a summary writer; the rest pass None.
    writer = (test_utils.get_summary_writer(flags.logdir)
              if xm.is_master_ordinal() else None)
    optimizer = optim.SGD(model.parameters(),
                          lr=scaled_lr,
                          momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    # Start up client side profiler server.
    server = xp.start_server(flags.profiler_port)
    # Testing purpose only: set event for synchronization.
    started_event = kwargs.get('worker_started')
    if started_event:
        kwargs.pop('worker_started')
        started_event.set()

    def train_loop_fn(loader):
        # One traced pass over the training loader.
        tracker = xm.RateTracker()
        model.train()
        for step_idx, (images, labels) in enumerate(loader):
            with xp.StepTrace('train_mnist', step_num=step_idx):
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    logits = model(images)
                    loss = loss_fn(logits, labels)
                    loss.backward()
                xm.optimizer_step(optimizer)

                tracker.add(flags.batch_size)
                if step_idx % flags.log_steps == 0:
                    closure_args = (device, step_idx, loss, tracker, writer)
                    xm.add_step_closure(_train_update, args=closure_args)

    def test_loop_fn(loader):
        # One traced pass over the test loader; returns accuracy in percent.
        seen = 0
        hits = 0
        model.eval()
        for images, labels in loader:
            with xp.StepTrace('test_mnist'):
                logits = model(images)
                _, predicted = logits.max(1, keepdim=True)
                hits += predicted.eq(labels.view_as(predicted)).sum()
                seen += images.size(0)

        local_accuracy = 100.0 * hits.item() / seen
        return xm.mesh_reduce('test_accuracy', local_accuracy, np.mean)

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    max_accuracy = 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy