Code Example #1
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to the model if the model to train has been instantiated from a local path.
                If present, we will try to reload the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
        )

        self.eval_history = []
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            epoch_iterator = tqdm(train_dataloader, desc=f"Epoch-{epoch}", disable=not self.is_local_master())
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if self.args.do_aug:
                    if self.args.aug_type == 'span_cutoff':
                        step_loss = self._training_step_with_span_cutoff(model, inputs, optimizer)
                    elif self.args.aug_type == 'token_cutoff':
                        step_loss = self._training_step_with_token_cutoff(model, inputs, optimizer)
                    elif self.args.aug_type == 'dim_cutoff':
                        step_loss = self._training_step_with_dim_cutoff(model, inputs, optimizer)
                    else:
                        raise NotImplementedError
                else:
                    step_loss = self._training_step(model, inputs, optimizer)

                tr_loss += step_loss

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.max_grad_norm > 0:
                        if self.args.fp16:
                            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss

                        print()
                        self._log(logs)

                        # if self.args.evaluate_during_training and self.args.save_steps % self.args.logging_steps == 0:
                        #     self.evaluate()

                    if self.is_world_master() and self.args.evaluate_during_training and \
                            self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        self.evaluate_and_save_model(model, optimizer, scheduler)

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break

            if self.is_world_master() and self.args.evaluate_during_training:
                self.evaluate_and_save_model(model, optimizer, scheduler)

            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info("\n\nTraining completed.\n\n")

        self.eval_history = sorted(self.eval_history, key=lambda x: x[0])
        for x in self.eval_history:
            del x[-1]
        report_results(self.eval_header, self.eval_history, axis=self.eval_key_axis)
        return TrainOutput(self.global_step, tr_loss / self.global_step)
Code Example #2
def distributed_init(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        from fairseq.dataclass.utils import convert_namespace_to_omegaconf

        cfg = convert_namespace_to_omegaconf(cfg)

    if not cfg.common.tpu:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            warnings.warn(
                "Distributed is already initialized, cannot initialize twice!")
        else:
            logger.info("distributed init (rank {}): {}".format(
                cfg.distributed_training.distributed_rank,
                cfg.distributed_training.distributed_init_method,
            ))
            dist.init_process_group(
                backend=cfg.distributed_training.distributed_backend,
                init_method=cfg.distributed_training.distributed_init_method,
                world_size=cfg.distributed_training.distributed_world_size,
                rank=cfg.distributed_training.distributed_rank,
            )
            logger.info("initialized host {} as rank {}".format(
                socket.gethostname(),
                cfg.distributed_training.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
    else:
        assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
        global _USE_XLA
        _USE_XLA = True
        cfg.distributed_training.device_id = xm.get_local_ordinal()
        cfg.distributed_training.distributed_rank = xm.get_ordinal()
        xm.rendezvous("distributed_init")  # wait for all workers
        xm.mark_step()

    if is_master(cfg.distributed_training):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if cfg.common.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError("\n\nPlease install the megatron submodule:"
                              "\n\n  git submodule update --init "
                              "fairseq/model_parallel/megatron")
        global _USE_MEGATRON
        _USE_MEGATRON = True
        initialize_model_parallel(cfg.common.model_parallel_size)
        model_parallel_cuda_manual_seed(cfg.common.seed)
        model_part_number = get_model_parallel_rank()
        cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(
            model_part_number)

    return cfg.distributed_training.distributed_rank
Code Example #3
    def distributed_sampler_kwargs(self):
        return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
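
A minimal usage sketch (not from the original project): the kwargs above are typically unpacked into torch.utils.data.distributed.DistributedSampler so that each TPU ordinal reads its own shard. The dataset and batch size below are illustrative assumptions.

import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Hypothetical toy dataset, used only to show how the kwargs are consumed.
dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))

# Same kwargs as the hook above: one replica per TPU core, ranked by ordinal.
sampler = DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
loader = DataLoader(dataset, batch_size=32, sampler=sampler, drop_last=True)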
Code Example #4
def train_mnist():
    torch.manual_seed(1)

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(os.path.join(FLAGS.datadir,
                                                    str(xm.get_ordinal())),
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        test_dataset = datasets.MNIST(os.path.join(FLAGS.datadir,
                                                   str(xm.get_ordinal())),
                                      train=False,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    # Scale learning rate to num cores
    lr = FLAGS.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, x, loss, tracker))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    max_accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print('Finished training epoch {}'.format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
Code Example #5
File: trainer.py  Project: yf1291/nlp4
def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset,
                              num_replicas=xm.xrt_world_size(),
                              rank=xm.get_ordinal())
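
A hedged usage sketch (the dataset and batch size are assumptions, not part of the snippet above): the sampler returned by get_tpu_sampler is normally handed to a DataLoader, falling back to plain random sampling when only one core is in use.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative dataset only; any map-style Dataset works here.
dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))

# RandomSampler on a single core, DistributedSampler when xrt_world_size() > 1.
sampler = get_tpu_sampler(dataset)
loader = DataLoader(dataset, batch_size=16, sampler=sampler, drop_last=True)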
Code Example #6
File: test_train_mnist.py  Project: whtngus/xla
def train_mnist():
    torch.manual_seed(1)

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(FLAGS.datadir,
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        test_dataset = datasets.MNIST(FLAGS.datadir,
                                      train=False,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(
                model.parameters(), lr=lr, momentum=FLAGS.momentum))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = sum(accuracies) / len(accuracies)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    return accuracy
Code Example #7
def _setup_replication():
  if xm.xrt_world_size() > 1:
    device = xm.xla_device()
    xm.set_replication(str(device), [str(device)])

Code Example #8

def train():
    torch.manual_seed(1)
    transform_train = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
        transforms.RandomErasing(p=0.5,
                                 scale=(0.02, 0.4),
                                 ratio=(0.3, 3.3),
                                 value=0.4914),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    train_dataset = FashionMNIST(root=os.path.join(FLAGS['data_dir'],
                                                   str(xm.get_ordinal())),
                                 train=True,
                                 download=True,
                                 transform=transform_train)

    test_dataset = FashionMNIST(root=os.path.join(FLAGS['data_dir'],
                                                  str(xm.get_ordinal())),
                                train=False,
                                download=False,
                                transform=transform_test)

    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=xm.xrt_world_size(),
                                       rank=xm.get_ordinal(),
                                       shuffle=True)

    train_loader = DataLoader(train_dataset,
                              batch_size=FLAGS['batch_size'],
                              sampler=train_sampler,
                              num_workers=FLAGS['num_workers'],
                              drop_last=True)

    test_loader = DataLoader(test_dataset,
                             batch_size=FLAGS['batch_size'],
                             shuffle=False,
                             num_workers=FLAGS['num_workers'],
                             drop_last=True)

    learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

    device = xm.xla_device()

    classes = [
        "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal",
        "Shirt", "Sneaker", "Bag", "Ankle boot"
    ]

    print(device)

    class BasicBlock(nn.Module):
        def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
            super(BasicBlock, self).__init__()
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.relu1 = nn.ReLU(inplace=True)
            self.conv1 = nn.Conv2d(in_planes,
                                   out_planes,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=1,
                                   bias=False)
            self.bn2 = nn.BatchNorm2d(out_planes)
            self.relu2 = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(out_planes,
                                   out_planes,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   bias=False)
            self.droprate = dropRate
            self.equalInOut = (in_planes == out_planes)
            self.convShortcut = (not self.equalInOut) and nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False) or None

        def forward(self, x):
            if not self.equalInOut:
                x = self.relu1(self.bn1(x))
            else:
                out = self.relu1(self.bn1(x))
            out = self.relu2(
                self.bn2(self.conv1(out if self.equalInOut else x)))
            if self.droprate > 0:
                out = F.dropout(out, p=self.droprate, training=self.training)
            out = self.conv2(out)
            return torch.add(x if self.equalInOut else self.convShortcut(x),
                             out)

    class NetworkBlock(nn.Module):
        def __init__(self,
                     nb_layers,
                     in_planes,
                     out_planes,
                     block,
                     stride,
                     dropRate=0.0):
            super(NetworkBlock, self).__init__()
            self.layer = self._make_layer(block, in_planes, out_planes,
                                          nb_layers, stride, dropRate)

        def _make_layer(self, block, in_planes, out_planes, nb_layers, stride,
                        dropRate):
            layers = []
            for i in range(nb_layers):
                layers.append(
                    block(i == 0 and in_planes or out_planes, out_planes,
                          i == 0 and stride or 1, dropRate))
            return nn.Sequential(*layers)

        def forward(self, x):
            return self.layer(x)

    class WideResNet(nn.Module):
        def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
            super(WideResNet, self).__init__()
            nChannels = [
                16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
            ]
            assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
            n = (depth - 4) // 6
            block = BasicBlock
            # 1st conv before any network block
            self.conv1 = nn.Conv2d(1,
                                   nChannels[0],
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   bias=False)
            # 1st block
            self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1,
                                       dropRate)
            # 2nd block
            self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2,
                                       dropRate)
            # 3rd block
            self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2,
                                       dropRate)
            # global average pooling and classifier
            self.bn1 = nn.BatchNorm2d(nChannels[3])
            self.relu = nn.ReLU(inplace=True)
            self.fc = nn.Linear(nChannels[3], num_classes)
            self.nChannels = nChannels[3]

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.Linear):
                    m.bias.data.zero_()

        def forward(self, x):
            out = self.conv1(x)
            out = self.block1(out)
            out = self.block2(out)
            out = self.block3(out)
            out = self.relu(self.bn1(out))
            out = F.avg_pool2d(out, 7)
            out = out.view(-1, self.nChannels)
            return self.fc(out)

    model = WideResNet(num_classes=FLAGS['num_classes'],
                       depth=28,
                       widen_factor=10).to(device)

    loss_func = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(),
                    lr=learning_rate,
                    momentum=0.9,
                    weight_decay=5e-4)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[150, 225],
                                         gamma=0.1)

    def train_loop_fn(loader):
        tracker = RateTracker()
        model.train()
        print("Start Training")
        for counter, (images, labels) in enumerate(loader, start=1):  # iterate the per-device ParallelLoader passed in, not the raw train_loader
            outputs = model(images)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS['batch_size'])
            if counter % FLAGS['log_steps'] == 0:
                print(
                    '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                    .format(xm.get_ordinal(), counter, loss.item(),
                            tracker.rate(), tracker.global_rate(),
                            time.asctime()),
                    flush=True)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy),
              flush=True)
        return accuracy, data, pred, target

        # print(f'Epoch [{epoch+1}/{N_EPOCHS}] Loss= {(running_loss/counter)}')
        # if((epoch+1)%125==0):
        #     model.eval()
        #     with torch.no_grad():
        #         y_pred=[]
        #         for images, labels in test_loader:
        #             images = images.to(device)
        #             labels = labels.to(device)
        #             outputs = model(images)
        #             # max returns (value ,index)
        #             _, preds = torch.max(outputs, 1)
        #             y_pred+=preds.tolist()
        #         print(classification_report(test.targets, y_pred, target_names=classes))
        #     model.train()
        #     torch.save(model.state_dict(), DATA_DIR+"/WRN-28-10_%d"%(epoch+1))
        # scheduler.step()

    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, FLAGS['num_epochs'] + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(
            para_loader.per_device_loader(device))
        if FLAGS['metrics_debug']:
            xm.master_print(met.metrics_report(), flush=True)
        scheduler.step()
    PATH = FLAGS['data_dir'] + "/WRN-28-10_F.pth"
    torch.save(model.state_dict(), PATH)
    return accuracy, data, pred, target
Code Example #9
def main(rank):
    
    #Seed - Added for TPU purposes
    torch.manual_seed(1)
       
    #Create log folder
    root = 'result_fg/'
    model = 'coco_model_'
    result_folder_name = 'images_' + FLAGS['log_dir']
    model_folder_name = 'models_' + FLAGS['log_dir']
    if not os.path.isdir(root):
        os.mkdir(root)
    if not os.path.isdir(root + result_folder_name):
        os.mkdir(root + result_folder_name)
    if not os.path.isdir(root + model_folder_name):
        os.mkdir(root + model_folder_name)
    
    #Save the script
    copyfile(os.path.basename(__file__), root + result_folder_name + '/' + os.path.basename(__file__))
    
    #Define transformation for dataset images - e.g scaling
    transform = transforms.Compose(
        [
            transforms.Resize((FLAGS['img_size'],FLAGS['img_size'])),  #transforms.Scale is deprecated; Resize is the current API
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    ) 
    #Load dataset
    category_names = FLAGS['category_names'].split(',')
    
    #Serial Executor - This is needed to spread inside TPU for memory purposes
    SERIAL_EXEC = xmp.MpSerialExecutor()
    
    #Define Dataset
    dataset = SERIAL_EXEC.run(
        lambda: CocoData(
            root = FLAGS['train_imgs_path'],
            annFile = FLAGS['train_annotation_path'],
            category_names = category_names,
            transform=transform,
            final_img_size=FLAGS['img_size']
        )
    )
    
    #Discard images contain very small instances  
    dataset.discard_small(min_area=0.03, max_area=1)
    
    #Define data sampler - Added for TPU purposes
    train_sampler = DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )

    #Define data loader
    train_loader = DataLoader( #Modified for TPU purposes
        dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers']
        # shuffle=True
    )

    #Define device - Added for TPU purposes
    device = xm.xla_device(devkind='TPU')

    #For evaluation define fixed masks and noises
    data_iter = iter(train_loader)
    sample_batched = next(data_iter)
    x_fixed = sample_batched['image'][0:FLAGS['num_test_img']]
    x_fixed = Variable(x_fixed.to(device))
    y_fixed = sample_batched['single_fg_mask'][0:FLAGS['num_test_img']]
    y_fixed = Variable(y_fixed.to(device))
    z_fixed = torch.randn((FLAGS['num_test_img'],FLAGS['noise_size']))
    z_fixed = Variable(z_fixed.to(device))
    
    #Define networks
    generator = Generator_FG(
        z_dim=FLAGS['noise_size'],
        label_channel=len(category_names),
        num_res_blocks=FLAGS['num_res_blocks']
    )

    discriminator_glob = Discriminator(
        channels=3+len(category_names)
    )

    discriminator_instance = Discriminator(
        channels=3+len(category_names),
        input_size=FLAGS['local_patch_size']
    )

    WRAPPED_GENERATOR = xmp.MpModelWrapper(generator) #Added for TPU purposes
    WRAPPED_DISCRIMINATOR_GLOB = xmp.MpModelWrapper(discriminator_glob) #Added for TPU purposes
    WRAPPED_DISCRIMINATOR_INSTANCE = xmp.MpModelWrapper(discriminator_instance) #Added for TPU purposes

    G_fg = WRAPPED_GENERATOR.to(device) #Modified for TPU purposes
    D_glob = WRAPPED_DISCRIMINATOR_GLOB.to(device) #Modified for TPU purposes
    D_instance = WRAPPED_DISCRIMINATOR_INSTANCE.to(device) #Modified for TPU purposes
    
    #Load parameters from pre-trained models
    if FLAGS['pre_trained_model_path'] is not None and FLAGS['pre_trained_model_epoch'] is not None:
        try:
            G_fg.load_state_dict(xser.load(FLAGS['pre_trained_model_path'] + 'G_fg_epoch_' + FLAGS['pre_trained_model_epoch']))
            D_glob.load_state_dict(xser.load(FLAGS['pre_trained_model_path'] + 'D_glob_epoch_' + FLAGS['pre_trained_model_epoch']))
            D_instance.load_state_dict(xser.load(FLAGS['pre_trained_model_path'] + 'D_local_epoch_' + FLAGS['pre_trained_model_epoch']))
  
            xm.master_print('Parameters are loaded!')
        except:
            xm.master_print('Error: Pre-trained parameters are not loaded!')
            pass
    
    #Define interpolation operation
    up_instance =  nn.Upsample(
        size=(FLAGS['local_patch_size'],FLAGS['local_patch_size']),
        mode='bilinear'
    )
    
    #Define pooling operation for the case that image size and local patch size are mismatched
    pooling_instance = nn.Sequential()
    if FLAGS['local_patch_size']!=FLAGS['img_size']:
        pooling_instance.add_module(
            '0',
            nn.AvgPool2d(int(FLAGS['img_size']/FLAGS['local_patch_size']))
        )
        
    #Define training loss function - binary cross entropy
    BCE_loss = nn.BCELoss()
    
    #Define feature matching loss
    criterionVGG = VGGLoss()
    criterionVGG = criterionVGG.to(device) #Modified for TPU Purposes
         
    #Define optimizer
    G_local_optimizer = optim.Adam(
        G_fg.parameters(),
        lr=FLAGS['lr'],
        betas=(0.0, 0.9)
    )
    D_local_optimizer = optim.Adam(
        list(filter(lambda p: p.requires_grad, D_glob.parameters())) + list(filter(lambda p: p.requires_grad, D_instance.parameters())),
        lr=FLAGS['lr'],
        betas=(0.0,0.9)
    )

    #Deine learning rate scheduler
    scheduler_G = lr_scheduler.StepLR(
        G_local_optimizer,
        step_size=FLAGS['optim_step_size'],
        gamma=FLAGS['optim_gamma']
    )
    scheduler_D = lr_scheduler.StepLR(
        D_local_optimizer,
        step_size=FLAGS['optim_step_size'],
        gamma=FLAGS['optim_gamma']
    )
    
    #----------------------------TRAIN-----------------------------------------
    xm.master_print('training start!')
    tracker = xm.RateTracker() #Added for TPU reasons
    start_time = time.time()
    
    for epoch in range(FLAGS['train_epoch']):
        epoch_start_time = time.time()
        para_loader = pl.ParallelLoader(train_loader, [device]) #Added for TPU purposes
        loader = para_loader.per_device_loader(device) #Added for TPU purposes
         
        D_local_losses = []
        G_local_losses = []
    
        y_real_ = torch.ones(FLAGS['batch_size'])
        y_fake_ = torch.zeros(FLAGS['batch_size'])
        y_real_ = Variable(y_real_.to(device)) #Modified for TPU purposes
        y_fake_ = Variable(y_fake_.to(device)) #Modified for TPU purposes

        data_iter = iter(loader)
        num_iter = 0

        while num_iter < len(loader): #Modified for TPU purposes 
            j=0
            while j < FLAGS['critic_iter'] and num_iter < len(loader):
                j += 1
                sample_batched = next(data_iter)
                num_iter += 1

                x_ = sample_batched['image']
                x_ = Variable(x_.to(device)) #Modified for TPU purposes

                y_ = sample_batched['single_fg_mask']
                y_ = Variable(y_.to(device)) #Modified for TPU purposes

                fg_mask = sample_batched['seg_mask']
                fg_mask = Variable(fg_mask.to(device)) #Modified for TPU purposes

                y_instances = sample_batched['mask_instance']
                bbox = sample_batched['bbox']
                
                mini_batch = x_.size()[0]
                if mini_batch != FLAGS['batch_size']:
                    break
                
                #Update discriminators - D 
                #Real examples
                D_glob.zero_grad()
                D_instance.zero_grad()
                    
                y_reduced = torch.sum(y_,1).clamp(0,1).view(y_.size(0),1,FLAGS['img_size'],FLAGS['img_size'])
                
                x_d = torch.cat([x_,fg_mask],1)
                
                x_instances = torch.zeros((FLAGS['batch_size'],3,FLAGS['local_patch_size'],FLAGS['local_patch_size']))
                x_instances = Variable(x_instances.to(device))
                y_instances = Variable(y_instances.to(device))
                y_instances = pooling_instance(y_instances)
                G_instances = torch.zeros((FLAGS['batch_size'],3,FLAGS['local_patch_size'],FLAGS['local_patch_size']))
                G_instances = Variable(G_instances.to(device))
                      
                #Obtain instances
                for t in range(x_d.size()[0]):
                    x_instance = x_[t,0:3,bbox[0][t]:bbox[1][t],bbox[2][t]:bbox[3][t]] 
                    x_instance = x_instance.contiguous().view(1,x_instance.size()[0],x_instance.size()[1],x_instance.size()[2]) 
                    x_instances[t] = up_instance(x_instance)
                    
                D_result_instance = D_instance(torch.cat([x_instances,y_instances],1)).squeeze()       
                D_result = D_glob(x_d).squeeze()
                D_real_loss = BCE_loss(D_result, y_real_) +  BCE_loss(D_result_instance, y_real_)
                D_real_loss.backward()
                
                #Fake examples
                z_ = torch.randn((mini_batch,FLAGS['noise_size']))
                z_ = Variable(z_.to(device))
    
                #Generate fake images
                G_fg_result = G_fg(z_,y_, torch.mul(x_,(1-y_reduced)))
                G_result_d = torch.cat([G_fg_result,fg_mask],1) 
                            
                #Obtain fake instances
                for t in range(x_d.size()[0]):
                    G_instance = G_result_d[t,0:3,bbox[0][t]:bbox[1][t],bbox[2][t]:bbox[3][t]] 
                    G_instance = G_instance.contiguous().view(1,G_instance.size()[0],G_instance.size()[1],G_instance.size()[2]) 
                    G_instances[t] = up_instance(G_instance)
                
                
                D_result_instance = D_instance(torch.cat([G_instances,y_instances],1).detach()).squeeze() 
                D_result = D_glob(G_result_d.detach()).squeeze() 
                D_fake_loss = BCE_loss(D_result, y_fake_) +  BCE_loss(D_result_instance, y_fake_)
                D_fake_loss.backward()

                xm.optimizer_step(D_local_optimizer) #Modified for TPU purposes
                
                D_train_loss = D_real_loss + D_fake_loss
                D_local_losses.append(D_train_loss.item())
    
            if mini_batch != FLAGS['batch_size']:
                break  
            
            #Update generator G
            G_fg.zero_grad()   
            D_result = D_glob(G_result_d).squeeze() 
            D_result_instance = D_instance(torch.cat([G_instances,y_instances],1)).squeeze() 
            G_train_loss = (1-FLAGS['trade_off_G'])*BCE_loss(D_result, y_real_) + FLAGS['trade_off_G']*BCE_loss(D_result_instance, y_real_) 
            
            #Feature matching loss between generated image and corresponding ground truth
            FM_loss = criterionVGG(G_fg_result, x_)
            
            #Reconstruction loss
            Recon_loss = mse_loss(torch.mul(x_,(1-y_reduced) ), torch.mul(G_fg_result,(1-y_reduced))  )
    
            total_loss = G_train_loss + FLAGS['lambda_FM']*FM_loss + FLAGS['lambda_recon']*Recon_loss
            total_loss.backward() 

            xm.optimizer_step(G_local_optimizer)

            G_local_losses.append(G_train_loss.item())
    
            xm.master_print('loss_d: %.3f, loss_g: %.3f' % (D_train_loss.item(), G_train_loss.item()))
            if (num_iter % 100) == 0:
                xm.master_print('%d - %d complete!' % ((epoch+1), num_iter))
                xm.master_print(result_folder_name)

        #Modified location of the scheduler step to avoid warning
        scheduler_G.step()
        scheduler_D.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        xm.master_print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), FLAGS['train_epoch'], per_epoch_ptime, torch.mean(torch.FloatTensor(D_local_losses)), torch.mean(torch.FloatTensor(G_local_losses))))
    
        #Save images
        G_fg.eval()
        
        if epoch == 0:
            show_result(
                (epoch+1),
                x_fixed,
                save=True,
                path=root + result_folder_name+ '/' + model + str(epoch + 1 ) + '_gt.png'
            )
            for t in range(y_fixed.size()[1]):
                show_result(
                    (epoch+1),
                    y_fixed[:,t:t+1,:,:],
                    save=True,
                    path=root + result_folder_name+ '/' + model + str(epoch + 1 ) +'_'+ str(t) +'_masked.png'
                )
            
        show_result(
            (epoch+1),
            G_fg(
                z_fixed,
                y_fixed,
                torch.mul(
                    x_fixed,
                    (1-torch.sum(y_fixed,1).view(y_fixed.size(0),1,FLAGS['img_size'],FLAGS['img_size']))
                )
            ),
            save=True,
            path=root + result_folder_name+ '/' + model + str(epoch + 1 ) + '_fg.png'
        )
        
        G_fg.train()
        
        #Save model params
        if FLAGS['save_models'] and (epoch>11 and epoch % 10 == 0 ):
            xser.save(
                G_fg.state_dict(),
                root + model_folder_name + '/' + model + 'G_fg_epoch_'+str(epoch)+'.pth',
                master_only=True
            )
            xser.save(
                D_glob.state_dict(),
                root + model_folder_name + '/' + model + 'D_glob_epoch_'+str(epoch)+'.pth',
                master_only=True
            )
            xser.save(
                D_instance.state_dict(),
                root + model_folder_name + '/' + model + 'D_local_epoch_'+str(epoch)+'.pth',
                master_only=True
            )
                         
    end_time = time.time()
    total_ptime = end_time - start_time
    xm.master_print("Training finish!... save training results")
    xm.master_print('Training time: ' + str(total_ptime))
Code Example #10
File: tpu.py  Project: hyang0129/nsga-net
    def run():
        """
        Main function to setup the training loop and evaluation loop.
        See comments for detailed explanation.

        Returns:
            None, but it saves the model weights and model performance, based on the get_map_fn arguments

        """

        # xla will assign a device for each forked run of this function
        device = xm.xla_device()

        # determine if this fork is the master fork to avoid logging and print 8 times
        master = xm.is_master_ordinal()

        if master:
            logger.info("running at batch size %i" % batch_size)

        criterion = nn.CrossEntropyLoss()

        criterion.to(device)
        model = WRAPPED_MODEL.to(device)

        # standard data prep
        CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
        CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

        train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ]
        )

        if args.cutout > 0:
            train_transform.transforms.append(Cutout(args.cutout))

        train_data = CifarDataset(transform=train_transform)

        # distributed samples ensure data is sharded to each tpu core
        # if you do not use this, you are only using 1 of the 8 cores
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True,
        )

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch_size//xm.xrt_world_size(),
            sampler=train_sampler,
            drop_last=True,
            num_workers=0,
        )

        valid_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ]
        )

        valid_data = my_cifar10.CIFAR10(
            root=data_root, train=False, download=False, transform=valid_transform
        )

        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False,
        )

        valid_queue = torch.utils.data.DataLoader(
            valid_data,
            sampler=valid_sampler,
            batch_size=batch_size//xm.xrt_world_size(),
            drop_last=True,
            num_workers=0,
        )

        # standard optimizer stuff
        parameters = filter(lambda p: p.requires_grad, model.parameters())

        if args.opt == "sgd":

            optimizer = torch.optim.SGD(
                parameters,
                args.learning_rate,
                momentum=momentum,
                weight_decay=args.weight_decay,
            )
        elif args.opt == "lamb":
            optimizer = Lamb(
                parameters, lr=args.learning_rate, weight_decay=weight_decay
            )
        else:
            raise NameError("Unknown Optimizer %s" % args.opt)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs))

        # training by epoch loop
        for epoch in range(epochs):

            # the model needs a droprate, so just assign it
            model.droprate = drop_path_prob * epoch / epochs

            start = datetime.datetime.now()
            st = start.strftime("%Y-%m-%d %H:%M:%S")

            if master:
                logger.info("starting epoch %i at %s" % (epoch, st))

            # parallel loader necessary to load data in parallel to each core
            para_loader = pl.ParallelLoader(train_queue, [device]).per_device_loader(
                device
            )
            correct, train_loss, total = train(
                para_loader, model, criterion, optimizer, params, device
            )

            train_acc = 100 * correct / total

            # collect the train accuracies from all cores
            train_acc = xm.mesh_reduce("avg acc", train_acc, np.mean)

            end = datetime.datetime.now()
            duration = (end - start).total_seconds()

            if master:
                logger.info("train_acc %f duration %f" % (train_acc, duration))

            scheduler.step()

        # validate using 8 cores and collect results
        valid_acc, valid_obj = infer(valid_queue, model, criterion, device)
        valid_acc = xm.mesh_reduce("val avg acc", valid_acc, np.mean)

        if master:
            logger.info("valid_acc %f" % valid_acc)

        # count flops
        _ = add_flops_counting_methods(model)
        model.eval()
        model.start_flops_count()
        random_data = torch.randn(1, 3, 32, 32)
        model(torch.autograd.Variable(random_data).to(device))
        n_flops = np.round(model.compute_average_flops_cost() / 1e6, 4)
        n_flops = xm.mesh_reduce("flops", n_flops, np.mean)

        if master:
            logger.info("flops %f" % n_flops)

        if master:
            logger.info("saving")

        # save weights and results

        xm.save([valid_acc, n_flops], "results.pt")
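
The run() function above is written to be forked once per TPU core: it resolves its own xm.xla_device() and gates logging on xm.is_master_ordinal(). A minimal launch sketch follows; it assumes run() is importable at module level, and the wrapper name and nprocs value are illustrative rather than taken from the original project.

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each forked process calls run(); the index argument is required by xmp.spawn.
    run()

if __name__ == "__main__":
    # nprocs=8 targets a full v2/v3 TPU board; nprocs=1 is convenient for debugging.
    xmp.spawn(_mp_fn, nprocs=8, start_method="fork")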
Code Example #11
    def map_fn(self, index, train_dataset, dev_dataset, lr, epochs, batch_size, callbacks):
        if self.using_tpu:
            device = xm.xla_device()
        else:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        train_loader = self.make_loader(train_dataset, batch_size, 'train')
        dev_loader = self.make_loader(dev_dataset, batch_size, 'dev')

        model = self.model.to(device)
        if self.using_tpu:
            opt = self.Opt([param for param in model.parameters() if param.requires_grad],
                           lr=lr*xm.xrt_world_size(), weight_decay=1e-4)  # hard coding
        else:
            opt = self.Opt([param for param in model.parameters() if param.requires_grad],
                           lr=lr, weight_decay=1e-4)  # hard coding

        loss_fn = self.Loss_fn(from_logits=True)

        callback_kwargs = {
            "model": model,
            "eval_dic": self.dev_eval,
        }

        for callback in callbacks:
            callback.train_init(**callback_kwargs)

        for epoch in range(epochs):
            if self.using_tpu:
                xm.rendezvous("training is starting!")
                if xm.is_master_ordinal():
                    print(f"\nepoch : {epoch+1} / {epochs}")
                now_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
            else:
                print(f"epoch : {epoch+1} / {epochs}")
                now_train_loader = train_loader
            model.train()
            for step, batch in enumerate(now_train_loader):
                logits, y, loss = self.compute_batch(model, batch, device, loss_fn, opt, phase='train')

                if self.using_tpu:
                    xm.rendezvous("update is starting!")
                    self.update(logits, y, loss, 'train', batch_size)
                    xm.rendezvous("update is ended!")
                    if xm.is_master_ordinal():
                        self.show_log(step*xm.xrt_world_size(), train_dataset, batch_size, 'train')
                else:
                    self.update(logits, y, loss, 'train', batch_size)
                    self.show_log(step, train_dataset, batch_size, 'train')

            if self.using_tpu:
                xm.rendezvous("batch is done!")
                if xm.is_master_ordinal():
                    print()
            else:
                print()

            model.eval()
            with torch.no_grad():
                if self.using_tpu:
                    now_dev_loader = pl.ParallelLoader(dev_loader, [device]).per_device_loader(device)
                else:
                    now_dev_loader = dev_loader
                for step, batch in enumerate(now_dev_loader):
                    logits, y, loss = self.compute_batch(model, batch, device, loss_fn, opt, phase='dev')

                    if self.using_tpu:
                        xm.rendezvous("update is starting!")
                        self.update(logits, y, loss, 'dev', batch_size)
                        xm.rendezvous("eval update is ended!")
                        if xm.is_master_ordinal():
                            self.show_log(step*xm.xrt_world_size(), dev_dataset, batch_size, 'dev')
                    else:
                        self.update(logits, y, loss, 'dev', batch_size)
                        self.show_log(step, dev_dataset, batch_size, 'dev')

                if self.using_tpu:
                    xm.rendezvous("batch is done!")
                    if xm.is_master_ordinal():
                        print()
                else:
                    print()
            self.on_epoch_end(callbacks)

        if self.using_tpu:
            xm.rendezvous("training is over!")
Code Example #12
def run():
    df1 = pd.read_csv("../input/jigsaw-multilingual-toxic-comment-train.csv",
                      usecols=['comment_text', 'toxic'])
    df1 = pd.read_csv("../input/jigsaw-unintended-bias-train.csv",
                      usecols=['comment_text', 'toxic'])

    #combined df1 and df2 and made big dataframe
    df_train = pd.concat([df1, df2], axis=0).reset_index(drop=True)

    #the validation dataframe is provided by Kaggle
    df_valid = pd.read_csv("../input/validation.csv")

    train_dataset = dataset.BERTDataset(
        comment_text=df_train.comment_text.values,
        target=df_train.toxic.values)

    #--------------------------------------
    #write sampler if using tpu else not
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    #----------------------------------------

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        num_workers=4,
        sampler=train_sampler,
        # torch_xla crashes if the per-core batch sizes differ, so drop the last incomplete batch
        drop_last=True)

    valid_dataset = dataset.BERTDataset(
        comment_text=df_valid.comment_text.values,
        target=df_valid.toxic.values)

    #--------------------------------------
    # use a DistributedSampler only when running on TPU
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    #----------------------------------------------

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=1,
        sampler=valid_sampler,
        #no need of drop_last here
    )

    device = xm.xla_device()  # xla_device() is the TPU device
    model = BERTBaseUncased()
    # model.to(device)  # no need to move the model to the device here

    #specify what parameters you want to train
    param_optimizer = list(model.named_parameters())

    #we don't want any deacy for these layer names such as bias and othr following things
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    optimizer_parameters = [
        {
            #don't decay weight for above no_decay list else decay
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.001,
        },
        {
            "params":
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay":
            0.0,
        },
    ]

    num_train_steps = int(
        len(df_train) / config.TRAIN_BATCH_SIZE / xm.xrt_world_size() *
        config.EPOCHS)

    lr = 3e-5 * xm.xrt_world_size()
    #experiment with lr
    optimizer = AdamW(optimizer_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    best_accuracy = 0
    for epoch in range(config.EPOCHS):

        #parallel loader for tpus
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        engine.train_fn(para_loader.per_device_loader(device), model,
                        optimizer, device, scheduler)

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        outputs, targets = engine.eval_fn(
            para_loader.per_device_loader(device), model, device)

        #threshold the target instead of output
        targets = np.array(targets) >= 0.5
        accuracy = metrics.accuracy_score(targets, outputs)
        print(f"Accuracy Score = {accuracy}")
        if accuracy > best_accuracy:

            #instead of torch.save use xm.save
            xm.save(model.state_dict(), config.MODEL_PATH)
            best_accuracy = accuracy
Code example #13
File: utils.py Project: pytorch/xla
def dummy_all_gather(value, dim=0, groups=None):
    """A dummy op for debugging with the same output shape as all_gather"""
    repeat_num = [1] * value.dim()
    repeat_num[dim] = xm.xrt_world_size()
    return value.repeat(tuple(repeat_num))
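
A minimal CPU-only sketch (an illustration assumption, not part of the original example) of how this dummy op mirrors the shape contract of xm.all_gather; world_size=8 is hard-coded in place of xm.xrt_world_size():

import torch

def dummy_all_gather_cpu(value, dim=0, world_size=8):
    # CPU stand-in mirroring the helper above: the "gathered" tensor is
    # world_size times larger along `dim`.
    repeat_num = [1] * value.dim()
    repeat_num[dim] = world_size
    return value.repeat(tuple(repeat_num))

t = torch.randn(4, 16)
assert dummy_all_gather_cpu(t).shape == (32, 16)  # dim 0 grew by world_size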
Code example #14
def _mp_fn(rank, flags):
    device = xm.xla_device()
    net.to(device)

    train_sampler = DistributedSamplerWrapper(
        sampler=BalanceClassSampler(labels=train_dataset.get_labels(), mode="downsampling"),
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TrainGlobalConfig.batch_size,
        sampler=train_sampler,
        pin_memory=False,
        drop_last=True,
        num_workers=TrainGlobalConfig.num_workers,
    )
    validation_sampler = torch.utils.data.distributed.DistributedSampler(
        validation_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=TrainGlobalConfig.batch_size,
        sampler=validation_sampler,
        pin_memory=False,
        drop_last=False,
        num_workers=TrainGlobalConfig.num_workers
    )
    validation_tune_sampler = torch.utils.data.distributed.DistributedSampler(
        validation_tune_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    validation_tune_loader = torch.utils.data.DataLoader(
        validation_tune_dataset,
        batch_size=TrainGlobalConfig.batch_size,
        sampler=validation_tune_sampler,
        pin_memory=False,
        drop_last=False,
        num_workers=TrainGlobalConfig.num_workers
    )
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=TrainGlobalConfig.batch_size,
        sampler=test_sampler,
        pin_memory=False,
        drop_last=False,
        num_workers=TrainGlobalConfig.num_workers
    )

    if rank == 0:
        time.sleep(1)
    
    fitter = TPUFitter(model=net, device=device, config=TrainGlobalConfig)
    fitter.fit(train_loader, validation_loader)
    fitter.run_tuning_and_inference(test_loader, validation_tune_loader)
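
A hedged launch sketch for an _mp_fn like the one above, assuming the objects it closes over (net, train_dataset, TrainGlobalConfig, ...) are already defined at module level; nprocs=8 targets an 8-core TPU and FLAGS here is a hypothetical placeholder:

import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == "__main__":
    FLAGS = {}  # placeholder; the _mp_fn above does not read its flags argument
    # spawn one process per TPU core; each process receives (rank, FLAGS)
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method="fork")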
Code example #15
def tpu_distributed() -> bool:
    return _TPU_AVAILABLE and xm.xrt_world_size() > 1
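
For context, a sketch of how such a guard is commonly used when choosing a sampler (assumes torch_xla is importable as xm and that tpu_distributed above is in scope; the dataset is synthetic):

import torch
from torch.utils.data import RandomSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 8))

if tpu_distributed():
    # shard the dataset across TPU cores
    sampler = DistributedSampler(dataset,
                                 num_replicas=xm.xrt_world_size(),
                                 rank=xm.get_ordinal())
else:
    # single device: a plain random sampler is enough
    sampler = RandomSampler(dataset)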
Code example #16
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, training_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        training_config (DictConfig): Training configuration; required
            for inferring params for the dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance,
                                                    other_args)

    if is_xla():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=other_args["shuffle"],
        )
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(dataset_instance.dataset_name,
                                 dataset_instance.dataset_type),
        num_workers=num_workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = pl.MpDeviceLoader(loader, device)

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
Code example #17
def train_loop(folds, fold):

    if CFG.device == 'GPU':
        LOGGER.info(f"========== fold: {fold} training ==========")
    elif CFG.device == 'TPU':
        if CFG.nprocs == 1:
            LOGGER.info(f"========== fold: {fold} training ==========")
        elif CFG.nprocs == 8:
            xm.master_print(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_cols].values

    train_dataset = TrainDataset(train_folds,
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds,
                                 transform=get_transforms(data='valid'))

    if CFG.device == 'GPU':
        train_loader = DataLoader(train_dataset,
                                  batch_size=CFG.batch_size,
                                  shuffle=True,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=CFG.batch_size * 2,
                                  shuffle=False,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=False)

    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=CFG.batch_size,
                                                   sampler=train_sampler,
                                                   drop_last=True,
                                                   num_workers=CFG.num_workers)

        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=CFG.batch_size *
                                                   2,
                                                   sampler=valid_sampler,
                                                   drop_last=False,
                                                   num_workers=CFG.num_workers)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler == 'ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=CFG.factor,
                                          patience=CFG.patience,
                                          verbose=True,
                                          eps=CFG.eps)
        elif CFG.scheduler == 'CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CFG.T_max,
                                          eta_min=CFG.min_lr,
                                          last_epoch=-1)
        elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                    T_0=CFG.T_0,
                                                    T_mult=1,
                                                    eta_min=CFG.min_lr,
                                                    last_epoch=-1)
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    if CFG.device == 'TPU':
        device = xm.xla_device()
    elif CFG.device == 'GPU':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = CustomResNet200D(CFG.model_name, pretrained=False)
    model.load_state_dict(
        torch.load(CFG.student, map_location=torch.device('cpu'))['model'])
    model.to(device)

    optimizer = Adam(model.parameters(),
                     lr=CFG.lr,
                     weight_decay=CFG.weight_decay,
                     amsgrad=False)
    scheduler = get_scheduler(optimizer)

    # ====================================================
    # loop
    # ====================================================
    train_fc = FocalLoss(alpha=CFG.alpha,
                         gamma=CFG.gamma,
                         logits=True,
                         reduce=True)
    valid_fc = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_loss = train_fn(train_loader, model, train_fc, valid_fc,
                                    optimizer, epoch, scheduler, device)
            elif CFG.nprocs == 8:
                para_train_loader = pl.ParallelLoader(train_loader, [device])
                avg_loss = train_fn(
                    para_train_loader.per_device_loader(device), model,
                    train_fc, valid_fc, optimizer, epoch, scheduler, device)
        elif CFG.device == 'GPU':
            avg_loss = train_fn(train_loader, model, train_fc, valid_fc,
                                optimizer, epoch, scheduler, device)

        # eval
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_val_loss, preds, _ = valid_fn(valid_loader, model,
                                                  valid_fc, device)
            elif CFG.nprocs == 8:
                para_valid_loader = pl.ParallelLoader(valid_loader, [device])
                avg_val_loss, preds, valid_labels = valid_fn(
                    para_valid_loader.per_device_loader(device), model,
                    valid_fc, device)
                preds = idist.all_gather(torch.tensor(preds)).to('cpu').numpy()
                valid_labels = idist.all_gather(
                    torch.tensor(valid_labels)).to('cpu').numpy()
        elif CFG.device == 'GPU':
            avg_val_loss, preds, _ = valid_fn(valid_loader, model, valid_fc,
                                              device)

        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()

        # scoring
        score, scores = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        if CFG.device == 'GPU':
            LOGGER.info(
                f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
            )
            LOGGER.info(
                f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
            )
        elif CFG.device == 'TPU':
            if CFG.nprocs == 1:
                LOGGER.info(
                    f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
                )
                LOGGER.info(
                    f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
                )
            elif CFG.nprocs == 8:
                xm.master_print(
                    f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
                )
                xm.master_print(
                    f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
                )

        if score > best_score:
            best_score = score
            if CFG.device == 'GPU':
                LOGGER.info(
                    f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                )
                torch.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(
                        f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                    )
                elif CFG.nprocs == 8:
                    xm.master_print(
                        f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                    )
                xm.save({
                    'model': model,
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            if CFG.device == 'GPU':
                LOGGER.info(
                    f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                torch.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(
                        f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model'
                    )
                elif CFG.nprocs == 8:
                    xm.master_print(
                        f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model'
                    )
                xm.save({
                    'model': model,
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')

        # save every epoch's weights so they can all be used for inference
        if CFG.device == 'TPU':
            xm.save({'model': model}, OUTPUT_DIR +
                    f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')
        elif CFG.device == 'GPU':
            torch.save({'model': model.state_dict()}, OUTPUT_DIR +
                       f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')

        if CFG.nprocs != 8:
            check_point = torch.load(
                OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
            for c in [f'pred_{c}' for c in CFG.target_cols]:
                valid_folds[c] = np.nan
            valid_folds[[f'pred_{c}'
                         for c in CFG.target_cols]] = check_point['preds']

    return valid_folds
Code example #18
def train_imagenet():
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size(),
        )
        if FLAGS.validate:
            test_loader = xu.SampleGenerator(
                data=(
                    torch.zeros(FLAGS.test_set_batch_size, 3, img_dim,
                                img_dim),
                    torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
                ),
                sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
            )
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        if FLAGS.validate:
            test_dataset = torchvision.datasets.ImageFolder(
                os.path.join(FLAGS.datadir, "val"),
                # Matches Torchvision's eval transforms except Torchvision uses size
                # 256 resize for all models both here and in the train loader. Their
                # version crashes during training on 299x299 images, e.g. inception.
                transforms.Compose([
                    transforms.Resize(resize_dim),
                    transforms.CenterCrop(img_dim),
                    transforms.ToTensor(),
                    normalize,
                ]),
            )

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            if FLAGS.validate:
                test_sampler = torch.utils.data.distributed.DistributedSampler(
                    test_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        if FLAGS.validate:
            test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=FLAGS.test_set_batch_size,
                sampler=test_sampler,
                drop_last=FLAGS.drop_last,
                shuffle=False,
                num_workers=FLAGS.num_workers,
            )

    device = xm.xla_device()
    model = get_model_property("model_fn")().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
        scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, "lr_scheduler_divide_every_n_epochs", None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer,
    )
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()

    def train_loop_fn(loader, epoch):
        if FLAGS.fine_grained_metrics:
            epoch_start_time = time.time()
            step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
        else:
            tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if FLAGS.fine_grained_metrics:
                step_start_time = time.time()
            optimizer.zero_grad()
            if FLAGS.fine_grained_metrics:
                fwd_start_time = time.time()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)
            if FLAGS.fine_grained_metrics:
                fwd_end_time = time.time()
                fwd_latency = fwd_end_time - fwd_start_time

                bwd_start_time = time.time()
            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            if lr_scheduler:
                lr_scheduler.step()
            if FLAGS.fine_grained_metrics:
                bwd_end_time = time.time()
                bwd_latency = bwd_end_time - bwd_start_time

                step_latency = bwd_end_time - step_start_time
                step_latency_tracker.append(step_latency)
                bwd_latency_tracker.append(bwd_latency)
                fwd_latency_tracker.append(fwd_latency)
            else:
                tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                if FLAGS.fine_grained_metrics:
                    print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(\
                                                epoch, step, FLAGS.batch_size/p50(step_latency_tracker), FLAGS.batch_size, p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
                else:
                    # _train_update(device, step, loss, tracker, epoch, writer)
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              epoch, writer))
        if FLAGS.fine_grained_metrics:
            epoch_end_time = time.time()
            epoch_latency = epoch_end_time - epoch_start_time
            print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(\
                                            epoch, epoch_latency, FLAGS.batch_size/p50(step_latency_tracker), FLAGS.batch_size, p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                test_utils.print_test_update(device, None, epoch, step)
                # xm.add_step_closure(test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    if FLAGS.validate:
        test_device_loader = pl.MpDeviceLoader(test_loader, device)
        accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print("Epoch {} train end {}".format(epoch,
                                                       test_utils.now()))
        if FLAGS.validate:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={"Accuracy/test": accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    if FLAGS.validate:
        xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy if FLAGS.validate else None
Code example #19
def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
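
A minimal usage sketch, assuming this helper and torch_xla are in scope; the dataset here is synthetic and the batch size arbitrary:

import torch
from torch.utils.data import DataLoader, TensorDataset

train_ds = TensorDataset(torch.randn(128, 10), torch.randint(0, 2, (128,)))
train_loader = DataLoader(train_ds,
                          batch_size=16,
                          sampler=get_tpu_sampler(train_ds, batch_size=16),
                          drop_last=True)  # keep per-core batch sizes equal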
Code example #20
def create_loader(dataset,
                  input_size,
                  batch_size,
                  is_training=False,
                  use_prefetcher=True,
                  rand_erase_prob=0.,
                  rand_erase_mode='const',
                  interpolation='bilinear',
                  mean=IMAGENET_DEFAULT_MEAN,
                  std=IMAGENET_DEFAULT_STD,
                  num_workers=1,
                  distributed=False,
                  crop_pct=None,
                  collate_fn=None,
                  tf_preprocessing=False,
                  use_auto_aug=False,
                  use_mixcut=False):
    if isinstance(input_size, tuple):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(is_training=is_training,
                                          size=img_size)
    else:
        if is_training:
            transform = transforms_imagenet_train(
                img_size,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                use_auto_aug=use_auto_aug,
                use_mix_cut=use_mixcut)
        else:
            transform = transforms_imagenet_eval(img_size,
                                                 interpolation=interpolation,
                                                 use_prefetcher=use_prefetcher,
                                                 mean=mean,
                                                 std=std,
                                                 crop_pct=crop_pct)

    dataset.transform = transform

    sampler = None

    # if distributed:
    #     if is_training:
    #         sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    #     else:
    #         # This will add extra duplicate entries to result in equal num
    #         # of samples per-process, will slightly alter validation results
    #         sampler = OrderedDistributedSampler(dataset)

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
    if xm.xrt_world_size() > 1:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        drop_last=is_training,
    )

    if use_prefetcher:
        loader = PrefetchLoader(
            loader,
            rand_erase_prob=rand_erase_prob if is_training else 0.,
            rand_erase_mode=rand_erase_mode,
            mean=mean,
            std=std)

    return loader
Code example #21
def training(rank, world_size, backend, config):
    # Specific xla
    print(xm.get_ordinal(), ": run with config:", config, "- backend=", backend)
    device = xm.xla_device()

    # Data preparation
    dataset = RndDataset(nb_samples=config["nb_samples"])

    # Specific xla
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(),
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(config["batch_size"] / xm.xrt_world_size()),
        num_workers=1,
        sampler=train_sampler,
    )

    # Specific xla
    para_loader = pl.MpDeviceLoader(train_loader, device)

    # Model, criterion, optimizer setup
    model = wide_resnet50_2(num_classes=100).to(device)
    criterion = NLLLoss()
    optimizer = SGD(model.parameters(), lr=0.01)

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(batch_idx, data, target):

        # data and target already live on the XLA device via MpDeviceLoader

        optimizer.zero_grad()
        output = model(data)
        # NLLLoss expects log-probabilities over the class dimension
        log_probs = torch.nn.functional.log_softmax(output, dim=1)

        loss_val = criterion(log_probs, target)
        loss_val.backward()
        xm.optimizer_step(optimizer)

        if batch_idx % log_interval == 0:
            print(
                "Process {}/{} Train Epoch: {} [{}/{}]\tLoss: {}".format(
                    xm.get_ordinal(),
                    xm.xrt_world_size(),
                    epoch,
                    batch_idx * len(data),
                    len(train_sampler),
                    loss_val.item(),
                )
            )
        return loss_val

    # Running _train_step for n_epochs
    n_epochs = 1
    for epoch in range(n_epochs):
        for batch_idx, (data, target) in enumerate(para_loader):
            _train_step(batch_idx, data, target)
Code example #22
File: trainer.py Project: Matimath/transformers
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt")))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        model.to(self.args.device)
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})
        if is_wandb_available():
            self._setup_wandb()

        # Train!
        if is_tpu_available():
            num_examples = len(train_dataloader._loader._loader.dataset)
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size(
            )
        else:
            num_examples = len(train_dataloader.dataset)
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=not self.is_local_master())
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.args.max_grad_norm)

                    if is_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    if self.is_local_master():
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps
                                == 0) or (global_step == 1
                                          and self.args.logging_first_step):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss
                                           ) / self.args.logging_steps
                            learning_rate_scalar = scheduler.get_last_lr()[0]
                            logs["learning_rate"] = learning_rate_scalar
                            logs["loss"] = loss_scalar
                            logging_loss = tr_loss

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(
                                        k, v, global_step)
                            if is_wandb_available():
                                wandb.log(logs, step=global_step)

                            epoch_iterator.write(
                                json.dumps({
                                    **logs,
                                    **{
                                        "step": global_step
                                    }
                                }))

                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                            # In all cases (even distributed/parallel), self.model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.model
                            else:
                                assert model is self.model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}")

                            self.save_model(output_dir)
                            self._rotate_checkpoints()
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))
                            logger.info(
                                "Saving optimizer and scheduler states to %s",
                                output_dir)

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(global_step, tr_loss / global_step)
Code example #23
    def world_size(self) -> int:
        return xm.xrt_world_size()
Code example #24
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, x, loss, tracker))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    max_accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print('Finished training epoch {}'.format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
Code example #25
File: trainer.py Project: yf1291/nlp4
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"),
                           map_location=self.args.device))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size(
            )
        else:
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(
                    train_dataloader,
                    [self.args.device]).per_device_loader(self.args.device)
                epoch_iterator = tqdm(parallel_loader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss -
                                        logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >=
                            version.parse("1.4") else scheduler.get_lr()[0])
                        logging_loss = tr_loss

                        self._log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(
                            self.args.output_dir,
                            f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(),
                                    os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step, tr_loss / self.global_step)
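The resume logic in the example above derives its step counters from the checkpoint directory name alone. A minimal sketch of that arithmetic, using illustrative values (a checkpoint-500 directory, 1000 batches per epoch, gradient_accumulation_steps=4) that are not taken from the source:

# Hypothetical values, chosen only to illustrate the arithmetic above.
model_path = "output/checkpoint-500/"
batches_per_epoch = 1000                 # stands in for len(train_dataloader)
gradient_accumulation_steps = 4

global_step = int(model_path.split("-")[-1].split("/")[0])             # 500
updates_per_epoch = batches_per_epoch // gradient_accumulation_steps   # 250
epochs_trained = global_step // updates_per_epoch                      # 2
steps_trained_in_current_epoch = global_step % updates_per_epoch       # 0
print(epochs_trained, steps_trained_in_current_epoch)                  # 2 0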
Code Example #26
def run(epochs, batch_size, num_workers, learning_rate, warmup_steps,
        pretrained_model, dropout):

    # datasets, samplers and dataloaders
    trainset = JigsawDataset(input_ids=train_input_ids,
                             token_type_ids=train_token_type_ids,
                             attention_mask=train_attention_mask,
                             targets=train_targets)

    validset = JigsawDataset(input_ids=valid_input_ids,
                             token_type_ids=valid_token_type_ids,
                             attention_mask=valid_attention_mask,
                             targets=valid_targets)

    # samplers
    trainsampler = DistributedSampler(dataset=trainset,
                                      num_replicas=xm.xrt_world_size(),
                                      rank=xm.get_ordinal(),
                                      shuffle=True)

    validsampler = DistributedSampler(dataset=validset,
                                      num_replicas=xm.xrt_world_size(),
                                      rank=xm.get_ordinal(),
                                      shuffle=False)

    # dataloaders
    trainloader = DataLoader(
        dataset=trainset,
        batch_size=batch_size,
        sampler=trainsampler,
        num_workers=num_workers,
        drop_last=True,
    )

    validloader = DataLoader(dataset=validset,
                             batch_size=batch_size,
                             sampler=validsampler,
                             drop_last=True)

    xm.master_print(f"Loading datasets....Complete!")

    # model
    device = xm.xla_device()
    model = BertBaseUncased(pretrained_model, dropout)
    model = model.to(device)
    xm.master_print(f"Loading model....Complete!")

    # training_parameters, optimizers and schedulers
    not_decay = ['LayerNorm.weight', 'LayerNorm.bias', 'bias']

    parameters = list(model.named_parameters())

    # Apply weight decay only to parameters that are not LayerNorm weights or biases;
    # the second group exists precisely so those parameters get no decay.
    train_parameters = [
        {
            'params': [p for n, p in parameters
                       if not any(nd in n for nd in not_decay)],
            'weight_decay': 0.001,
        },
        {
            'params': [p for n, p in parameters
                       if any(nd in n for nd in not_decay)],
            'weight_decay': 0.0,
        },
    ]

    # Optimizer steps per epoch: the dataset is sharded across cores by the
    # DistributedSampler and then batched, so divide by both factors.
    num_training_steps = int(len(trainset) / (xm.xrt_world_size() * batch_size))
    xm.master_print(f"Iterations per epoch: {num_training_steps}")

    optimizer = AdamW(train_parameters, lr=learning_rate)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps)

    # training and evaluation
    for epoch in range(epochs):
        # train
        para_train_loader = pl.ParallelLoader(trainloader, [device])

        start_time = time.time()

        train_loss = train_fn(model,
                              para_train_loader.per_device_loader(device),
                              optimizer,
                              device,
                              scheduler=scheduler)

        end_time = time.time()
        time_per_epoch = end_time - start_time
        xm.master_print(f"Time taken: {time_per_epoch}")

        xm.master_print(
            f"epoch: {epoch+1}/{epochs}, train loss: {np.mean(train_loss):.4f}"
        )

        # eval
        para_valid_loader = pl.ParallelLoader(validloader, [device])
        outputs, targets = valid_fn(
            para_valid_loader.per_device_loader(device), model, device)

        auc = metrics.roc_auc_score(np.array(targets) > 0.5, outputs)
        xm.master_print(f"auc_score: {auc:.4f}")
Code Example #27
def train_mnist(flags, **kwargs):
    torch.manual_seed(1)

    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=60000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=10000 // flags.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(os.path.join(flags.datadir,
                                                    str(xm.get_ordinal())),
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        test_dataset = datasets.MNIST(os.path.join(flags.datadir,
                                                   str(xm.get_ordinal())),
                                      train=False,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    # Start up the client-side profiler server.
    server = xp.start_server(flags.profiler_port)
    # For testing purposes only: set an event for synchronization.
    if kwargs.get('worker_started'):
        kwargs.pop('worker_started').set()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            with xp.StepTrace('train_mnist', step_num=step):
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)

                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              writer))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace('test_mnist'):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
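train_mnist references a _train_update callback and is normally launched once per core. The following is a hedged reconstruction of those surrounding pieces in the style of the PyTorch/XLA example scripts; test_utils and FLAGS are assumed to come from the same script, and the exact signatures may differ by version:

import torch_xla.distributed.xla_multiprocessing as xmp

def _train_update(device, step, loss, tracker, writer):
    # Runs via xm.add_step_closure once the step's graph has executed,
    # so loss.item() does not force an extra device sync inside the loop.
    test_utils.print_training_update(device,
                                     step,
                                     loss.item(),
                                     tracker.rate(),
                                     tracker.global_rate(),
                                     summary_writer=writer)

def _mp_fn(index, flags):
    # One process per TPU core; `flags` is forwarded through xmp.spawn's args.
    train_mnist(flags)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)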
Code Example #28
def tpu_distributed() -> bool:
    if _TPU_AVAILABLE:
        return xm.xrt_world_size() > 1
    return False
Code Example #29
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
    if xm.xrt_world_size() <= 1:
        return SequentialSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False)
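A short usage sketch for the helper above; the dataloader arguments are illustrative, not taken from the source:

from torch.utils.data import DataLoader

def build_train_dataloader(train_dataset, batch_size):
    # On a single process this falls back to a plain SequentialSampler;
    # with multiple TPU cores each ordinal reads a disjoint shard.
    sampler = get_tpu_sampler(train_dataset)
    return DataLoader(train_dataset,
                      batch_size=batch_size,
                      sampler=sampler,
                      drop_last=True)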
Code Example #30
def train_cifar():
    print('==> Preparing data..')

    if FLAGS.fake_data:
        train_dataset_len = 50000  # Number of examples in the CIFAR-10 training set.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                     train=True,
                                                     download=True,
                                                     transform=transform_train)
        train_dataset_len = len(train_dataset)
        test_dataset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                    train=False,
                                                    download=True,
                                                    transform=transform_test)
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
    model_parallel = dp.DataParallel(model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = test_utils.get_summary_writer(FLAGS.logdir)
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         global_step)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
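Because dp.DataParallel handles the per-device replication itself, train_cifar is normally invoked once from the main process. A minimal entry-point sketch, assuming FLAGS has already been populated by the example script's argument parser and that a target_accuracy flag exists (an assumption, not shown in the source):

if __name__ == '__main__':
    final_accuracy = train_cifar()
    # Hypothetical accuracy gate; FLAGS.target_accuracy is assumed.
    if final_accuracy < FLAGS.target_accuracy:
        print('Accuracy {:.2f}% is below target {:.2f}%'.format(
            final_accuracy, FLAGS.target_accuracy))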