Example #1
0
 def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
   """Replicate a network across a set of XLA devices.

   Args:
     network: A `torch.nn.Module` instance, or a zero-argument callable
       returning one.
     device_ids: Devices to replicate onto; defaults to all supported XLA
       devices. When none are available, falls back to a single native run.
     batchdim: Dimension along which input batches are split.
     drop_last: Whether a trailing partial batch is dropped.
   """
   if device_ids is None:
     device_ids = xm.get_xla_supported_devices()
   self._device_ids = list(device_ids)
   self._batchdim = batchdim
   self._drop_last = drop_last
   self._native_run = False
   self._replication = None
   if len(self._device_ids) > 1:
     # Cross-replica execution is only needed for more than one device.
     self._replication = xm.Replication(
         self._device_ids, xm.xla_replication_devices(self._device_ids))
   self._models = []
   self._contexts = []
   module = network if isinstance(network, torch.nn.Module) else network()
   for devname in self._device_ids:
     xdev = torch.device(devname)
     self._models.append(deepcopy(module).to(device=xdev))
     self._contexts.append(Context(xdev))
   if not self._models:
     # No XLA device, push a vanilla network in.
     device = self._get_model_device(module)
     self._models.append(module)
     self._device_ids.append(device)
     self._contexts.append(Context(torch.device(device)))
     self._native_run = True
Example #2
0
def run_benchmark(args, pos_args):
    """Benchmark concurrent host-to-XLA-device tensor transfers.

    One thread per device repeatedly pushes a pre-generated list of random
    tensors to its device, timing each send, then prints the XLA metrics
    report.

    Args:
        args: Parsed options; uses `max_devices`, `shape` (comma-separated
            dims), `prefetch` (tensors per send) and `test_count`.
        pos_args: Unused positional arguments (kept for CLI compatibility).
    """
    devices = xm.get_xla_supported_devices(max_devices=args.max_devices)
    shape = [int(x) for x in args.shape.split(',')]

    # Pre-generate host tensors so the timed section measures only the
    # transfer, not the random generation.
    send_list = [[torch.randn(*shape) for _ in range(args.prefetch)]
                 for _ in devices]

    def threadfn(i):
        # Each thread targets a single device and sends its own tensor list.
        device = devices[i]
        xdevices = [device] * len(send_list[i])
        for n in range(args.test_count):
            with xu.TimedScope(msg='Send[{}][{}]: '.format(i, n),
                               printfn=print):
                _ = torch_xla._XLAC._xla_tensors_from_aten(
                    send_list[i], xdevices)

    threads = []
    for i in range(len(devices)):
        t = threading.Thread(target=threadfn, args=(i,))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    print(torch_xla._XLAC._xla_metrics_report())
Example #3
0
    def test(self):
        """Train resnet18 under DataParallel and check the loss stays bounded."""
        devices = xm.get_xla_supported_devices()
        batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
        sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
        images = torch.zeros(batch_size, 3, 224, 224)
        labels = torch.zeros(batch_size, dtype=torch.int64)
        train_loader = xu.SampleGenerator(
            data=(images, labels), sample_count=sample_count * len(devices))

        def loop_fn(model, loader, device, context):
            loss_fn = nn.NLLLoss()
            optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

            for x, (data, target) in loader:
                with xu.TimedScope(msg='Training loop: ', printfn=None):
                    optimizer.zero_grad()
                    output = xu.timed(
                        lambda: model(data), msg='Model: ', printfn=None)
                    loss = xu.timed(
                        lambda: loss_fn(output, target),
                        msg='Loss: ',
                        printfn=None)
                    xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
                    xu.timed(
                        lambda: xm.optimizer_step(optimizer),
                        msg='Step: ',
                        printfn=None)
                    self.assertLess(loss.cpu().item(), 3.0)

        model_parallel = dp.DataParallel(
            torchvision.models.resnet18, device_ids=devices)
        model_parallel(loop_fn, train_loader)
Example #4
0
 def test(self):
   """Check ParallelLoader.to() lands batches on each requested device."""
   devices = xm.get_xla_supported_devices()
   scale = 3.11
   shift = 4.09
   batch_size = 128 * len(devices)
   gen = xu.FnDataGenerator(
       lambda v: v * scale + shift, batch_size, _gen_tensor, dims=[8], count=10)
   para_loader = pl.ParallelLoader(gen, batch_size, devices)
   for x, (data, target) in para_loader:
     for device in devices:
       moved = para_loader.to(data, device)
       self.assertEqual(moved.device, torch.device(device))
Example #5
0
File: data_parallel.py  Project: Saiuz/xla
 def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
     """Build one instance of `network` per XLA device.

     Args:
       network: A zero-argument callable returning a `torch.nn.Module`.
       device_ids: Devices to instantiate onto; defaults to all supported
         XLA devices. If empty, a single plain (non-XLA) model is created.
       batchdim: Dimension along which input batches are split.
       drop_last: Whether a trailing partial batch is dropped.
     """
     if device_ids is None:
         device_ids = xm.get_xla_supported_devices()
     self._batchdim = batchdim
     self._drop_last = drop_last
     self._device_ids = list(device_ids)
     if self._device_ids:
         self._replication = xm.Replication(self._device_ids)
     else:
         self._replication = None
     self._models = [
         network().to(device=torch.device(device)) for device in device_ids
     ]
     if not self._models:
         # No XLA device, push a vanilla network in.
         self._models.append(network())
Example #6
0
 def test(self):
     """Verify tensor .to() transfers between host and every XLA device."""
     devices = xm.get_xla_supported_devices()
     # Host tensor -> each device (iterate back-to-front for variety).
     for device in reversed(devices):
         src = _gen_tensor(8, 12)
         moved = src.to(device=torch.device(device))
         self.assertEqual(moved.device, torch.device(device))
     # First device -> every other device.
     base = _gen_tensor(8, 12).to(device=torch.device(devices[0]))
     for device in devices[1:]:
         moved = base.to(device=torch.device(device))
         self.assertEqual(moved.device, torch.device(device))
     # Adjacent device pairs stay numerically in sync after updates.
     for dev0, dev1 in zip(devices, devices[1:]):
         t0 = torch.zeros(4, 4, device=torch.device(dev0))
         t1 = t0.to(device=torch.device(dev1))
         t0 = t0 + torch.ones_like(t0, device=torch.device(dev0))
         t1 = t1 + torch.ones_like(t1, device=torch.device(dev1))
         self.assertEqual(t0.cpu(), t1.cpu())
Example #7
0
 def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
   """Replicate a network across a set of XLA devices.

   Args:
     network: A `torch.nn.Module` instance, or a zero-argument callable
       returning one.
     device_ids: Devices to replicate onto; defaults to all supported XLA
       devices. If empty, a single plain model is created instead.
     batchdim: Dimension along which input batches are split.
     drop_last: Whether a trailing partial batch is dropped.
   """
   if device_ids is None:
     device_ids = xm.get_xla_supported_devices()
   self._device_ids = list(device_ids)
   self._batchdim = batchdim
   self._drop_last = drop_last
   if self._device_ids:
     replication_devices = xm.xla_replication_devices(self._device_ids)
   else:
     replication_devices = None
   if replication_devices:
     self._replication = xm.Replication(self._device_ids, replication_devices)
   else:
     self._replication = None
   module = network if isinstance(network, torch.nn.Module) else network()
   self._models = [
       deepcopy(module).to(device=torch.device(device))
       for device in device_ids
   ]
   if not self._models:
     # No XLA device, push a vanilla network in.
     self._models.append(network())
Example #8
0
def main():
    """Pre-train a masked-LM BERT model on TPU cores via data parallelism.

    Parses command-line options, replicates the model over the available XLA
    devices, runs `tpu_training_loop` once per core per epoch, and saves a
    checkpoint after every epoch.
    """
    parser = utils.get_args_parser_with_general_args()
    parser.add_argument(
        '--one_tpu',
        action='store_true',
        help=
        "Run on one tpu core for degugging. Makes it easy to use break points")
    parser.add_argument('--tpu_report',
                        action='store_true',
                        help="Print xla metric report")
    args = parser.parse_args()

    utils.init(args)  # set seeds, init logger, prepare output directory

    devices = tpu_xm.get_xla_supported_devices()
    if args.one_tpu:
        devices = [devices[0]]
    n_tpu = len(devices)
    logging.info(f'Found {n_tpu} TPU cores')

    # Persist the tokenizer alongside the checkpoints so the output
    # directory is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    tokenizer.save_pretrained(args.output_dir)

    args.start_epoch = utils.prepare_last_checkpoint(args.bert_model)
    model = AutoModelWithLMHead.from_pretrained(
        args.bert_model)  # Only Masked Language Modeling
    logging.info(f"Saving initial checkpoint to: {args.output_dir}")
    model.save_pretrained(args.output_dir)
    model = tpu_dp.DataParallel(model, device_ids=devices)

    num_data_epochs, num_train_optimization_steps = utils.get_dataset_stats(
        args, n_tpu)

    def tpu_training_loop(model, loader, device, context):
        """ Called by torch_xla_py.data_parallel. This function is executed on each core of the TPU once per epoch"""

        # Apply weight decay to everything except biases and LayerNorm
        # parameters (standard BERT fine-tuning grouping).
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        # one optimizer and scheduler per TPU core. Both objects are saved in `context` to be reused the next epoch
        optimizer = context.getattr_or(
            'optimizer',
            AdamW(optimizer_grouped_parameters,
                  lr=args.learning_rate,
                  eps=args.adam_epsilon,
                  betas=tuple(args.betas)))

        # derive warmup info
        if args.warmup_proportion is not None:
            warmup_steps = int(args.warmup_proportion *
                               num_train_optimization_steps + 0.5)
        elif args.warmup_steps is not None:
            warmup_steps = args.warmup_steps
        else:
            raise Exception(
                'What is the warmup?? Specify either warmup proportion or steps'
            )
        scheduler = context.getattr_or(
            'scheduler',
            WarmupLinearSchedule(optimizer,
                                 warmup_steps=warmup_steps,
                                 t_total=num_train_optimization_steps))

        tr_loss = None
        pbar = None
        # NOTE(review): `pbar_device` and `pbar_steps` are closed over from
        # the epoch loop below; they are assigned there before this function
        # is invoked via `model(...)`.
        if str(pbar_device) == str(
                device
        ):  # All threads are in sync. Use progress bar only on one of them
            pbar = tqdm(total=int(pbar_steps),
                        desc=f"device {device}",
                        dynamic_ncols=True)

        tracker = tpu_xm.RateTracker()

        model.train()
        for step, batch in loader:
            input_ids, input_mask, segment_ids, lm_label_ids, _ = batch
            outputs = model(input_ids, segment_ids, input_mask, lm_label_ids)
            loss = outputs[0]
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tracker.add(args.train_batch_size)

            # Undo the accumulation scaling so `tr_loss` tracks the raw sum.
            tr_loss = loss * args.gradient_accumulation_steps if step == 0 else tr_loss + loss * args.gradient_accumulation_steps
            if pbar is not None:
                pbar.update(1)
                # pbar.set_description(desc=f'LR: {scheduler.get_lr()}')
            if (step + 1) % args.gradient_accumulation_steps == 0:
                tpu_xm.optimizer_step(optimizer)
                prev_lr = scheduler.get_last_lr()[0]
                scheduler.step()
                curr_lr = scheduler.get_last_lr()[0]
                if args.track_learning_rate:
                    if pbar is not None:
                        pbar.set_description(
                            f"Prev LR: {prev_lr} Curr LR: {curr_lr}")
                optimizer.zero_grad()

        # NOTE(review): dividing by `step` (the last loop index) rather than
        # the batch count (`step + 1`) looks like an off-by-one, and raises
        # ZeroDivisionError for a single-batch loader — confirm intent.
        return tr_loss.item(
        ) / step  # `.item()` requires a trip from TPU to CPU, which is very slow. Use it only once per epoch=

    for epoch in range(args.start_epoch, args.epochs):
        # Load one training file into memory
        epoch_dataset = utils.PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory)
        train_sampler = RandomSampler(epoch_dataset)
        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # The progress bar is rendered only by the first device's thread.
        pbar_device = devices[0]
        pbar_steps = utils.compute_num_steps_in_epoch(
            num_samples=train_sampler.num_samples,
            batch_size=args.train_batch_size,
            grad_accum_steps=
            1,  # the pbar steps should not take into account grad accumulation steps
            n_tpu=n_tpu)
        logging.info(
            f'start training, epoch {epoch} on {len(devices)} cores for {pbar_steps} steps'
        )
        start = time.time()
        losses = model(
            tpu_training_loop, train_dataloader
        )  # calls `tpu_training_loop` multiple times, once per TPU core
        logging.info(
            f'Epoch {epoch} took {round(time.time() - start, 2)} seconds. Average loss: {sum(losses)/len(losses)}'
        )
        utils.save_checkpoint(model._models[0], epoch, args.output_dir)

    if args.tpu_report:
        logging.info(torch_xla._XLAC._xla_metrics_report())
Example #9
0
def train_mnist():
    """Meta-train a relation network on the China-drinks dataset over XLA devices.

    Builds train/test episode loaders, replicates `CNN_Plus_RNEncoder` across
    the available XLA devices (or the CPU engine when none), then alternates a
    training pass and an evaluation pass per epoch.

    Returns:
      The last epoch's test accuracy, as a percentage.
    """
    torch.manual_seed(1)
    # Step 1: init data folders
    print("init data folders", flush=True)
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tgtpu.china_drinks_sku_folders(
        DATASET_FOLDER, SAMPLE_NUM_PER_CLASS, QUERY_NUM_PER_CLASS,
        VALIDATION_SPLIT_PERCENTAGE)

    devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
    # Scale learning rate to num cores
    # NOTE(review): `lr` is never used; the optimizer below uses the
    # module-level LEARNING_RATE — confirm which is intended.
    lr = FLAGS.lr * len(devices)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(CNN_Plus_RNEncoder, device_ids=devices)

    degrees = random.choice([0, 90, 180, 270])
    train_task = tgtpu.ChinaDrinksTask(metatrain_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       QUERY_NUM_PER_CLASS)
    train_sample_batch_dataloader = tgtpu.get_data_loader(
        train_task,
        image_size=IMAGE_SIZE,
        sample_num_per_class=SAMPLE_NUM_PER_CLASS,
        query_num_per_class=QUERY_NUM_PER_CLASS,
        train_shuffle=False,
        query_shuffle=True,
        rotation=degrees,
        num_workers=NO_OF_TPU_CORES)

    test_task = tgtpu.ChinaDrinksTask(metatest_character_folders, CLASS_NUM,
                                      SAMPLE_NUM_PER_CLASS,
                                      SAMPLE_NUM_PER_CLASS)
    test_sample_test_dataloader = tgtpu.get_data_loader(
        test_task,
        IMAGE_SIZE,
        sample_num_per_class=SAMPLE_NUM_PER_CLASS,
        query_num_per_class=QUERY_NUM_PER_CLASS,
        train_shuffle=False,
        query_shuffle=True,
        rotation=degrees,
        num_workers=NO_OF_TPU_CORES)

    def train_loop_fn(model, loader, device, context):
        # One MSE-trained episode pass over the loader for this device.
        relation_network = model
        #relation_network.apply(weights_init)

        relation_network_optim = torch.optim.Adam(
            relation_network.parameters(), lr=LEARNING_RATE)
        relation_network_scheduler = StepLR(relation_network_optim,
                                            step_size=100000,
                                            gamma=0.5)
        mse = nn.MSELoss()
        tracker = xm.RateTracker()

        for x, (samples, sample_labels, batches, batch_labels) in loader:

            # BUGFIX: the original called `scheduler.step(episode)` with an
            # undefined name (`episode`), raising NameError on the first
            # iteration; the loader step index `x` is the episode counter.
            relation_network_scheduler.step(x)

            relation_network.zero_grad()
            #relation_network_optim.zero_grad()
            relation_scores = relation_network(Variable(samples),
                                               Variable(batches))
            relations = relation_scores.view(-1, CLASS_NUM)
            one_hot_labels = Variable(
                torch.zeros(QUERY_NUM_PER_CLASS * CLASS_NUM,
                            CLASS_NUM).scatter_(1, batch_labels.view(-1, 1),
                                                1))
            loss = mse(relations, one_hot_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
            xm.optimizer_step(relation_network_optim)
            tracker.add(FLAGS.batch_size)
            print('Debug: ', x, loss.item())
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, x, loss.item(), tracker.rate()))

    def test_loop_fn(model, loader, device, context):
        # Count correct episode predictions for this device and return the
        # resulting accuracy.
        relation_network = model
        total_rewards = 0
        for x, (samples, sample_labels, batches, batch_labels) in loader:
            relation_scores = relation_network(Variable(samples),
                                               Variable(batches))
            relations = relation_scores.view(-1, CLASS_NUM)
            _, predict_labels = torch.max(relations.data, 1)
            # BUGFIX: the original compared against an undefined name
            # (`test_labels`); the ground truth for this batch is
            # `batch_labels`.
            rewards = [
                1 if predict_labels[j] == batch_labels[j] else 0
                for j in range(CLASS_NUM * SAMPLE_NUM_PER_CLASS)
            ]
            total_rewards += np.sum(rewards)

        test_accuracy = total_rewards / 1.0 / CLASS_NUM / SAMPLE_NUM_PER_CLASS / TEST_EPISODE
        print('[{}] Accuracy={:.2f}%'.format(device, 100 * test_accuracy))
        return test_accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_sample_batch_dataloader)
        accuracies = model_parallel(test_loop_fn, test_sample_test_dataloader)
        # BUGFIX: average over the number of returned results rather than
        # len(devices), which is zero when running on the CPU engine
        # (devices == []).
        accuracy = sum(accuracies) / len(accuracies)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy * 100.0
Example #10
0
        logging.warning(f'This will get logged to file: {args.log_file}')
    else:
        logging.basicConfig(level=logging.INFO, format=log_format)

    # create output dir
    if os.path.exists(args.output_dir):
        y_or_n = input(
            f'Output Dir {args.output_dir} already exists.  Write to same dir? (y/n)'
        )
        if y_or_n != 'y':
            raise Exception('Set new output dir')
    else:
        os.makedirs(args.output_dir, exist_ok=True)

    # TPU devices
    devices = tpu_xm.get_xla_supported_devices()
    if args.one_tpu:
        devices = [devices[0]]
    n_tpu = len(devices)
    logging.info(f'Found {n_tpu} TPU cores')

    # set seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    logging.info(f"Saving tokenizer to: {args.output_dir}")
    tokenizer.save_pretrained(args.output_dir)
Example #11
0
def train_cifar():
    """Train a binarized WRN-McDonnell on CIFAR-100 across XLA devices.

    Builds augmented train/test loaders, replicates the model over the
    configured devices (or the CPU engine when `FLAGS.num_cores == 0`),
    trains with SGD + cosine-restart LR, and checkpoints the best model.

    Returns:
      The last epoch's test accuracy, as a percentage.
    """
    print('==> Preparing data..')

    transform_train = transforms.Compose([
        transforms.Lambda(lambda x: RandomPixelPad(x, padding=4)),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        Cutout(18, random_pixel=True),  # add Cutout
        transforms.Normalize((0.5071, 0.4865, 0.4409),
                             (0.2673, 0.2564, 0.2762)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4865, 0.4409),
                             (0.2673, 0.2564, 0.2762)),
    ])

    trainset = torchvision.datasets.CIFAR100(root=FLAGS.datadir,
                                             train=True,
                                             download=True,
                                             transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers)

    testset = torchvision.datasets.CIFAR100(root=FLAGS.datadir,
                                            train=False,
                                            download=True,
                                            transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=FLAGS.batch_size,
                                              shuffle=False,
                                              num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])

    # Define model here
    model = WRN_McDonnell(20, 10, 100, binarize=True)

    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # One SGD training pass over the loader for this device; optimizer
        # and scheduler are cached in `context` so they survive across epochs.
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))

        # LR scheduler
        scheduler = context.getattr_or(
            'scheduler',
            lambda: CosineAnnealingRestartsLR(optimizer, T=2, eta_min=1e-4))

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f}'.format(device, x, loss.item()))

        # Step LR scheduler once per epoch (this function runs once per epoch).
        scheduler.step()

    def test_loop_fn(model, loader, device, context):
        # Return this device's top-1 accuracy over the test loader.
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        return correct / total_samples

    best_accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        # BUGFIX: average over the number of returned results rather than
        # len(devices), which is zero when running on the CPU engine
        # (devices == [], explicitly supported above).
        accuracy = sum(accuracies) / len(accuracies)

        print('Epoch {}, Accuracy={:.2f}%'.format(epoch, 100.0 * accuracy))

        # Keep track of best model
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model_parallel._models[0].state_dict(), 'model.pt')

        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy * 100.0
Example #12
0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The train file path")
    parser.add_argument("--eval_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The dev file path")
    parser.add_argument("--predict_file",
                        default=None,
                        type=str,
                        required=False,
                        help="The predict file path")
    parser.add_argument("--predict_result_file",
                        default=None,
                        type=str,
                        required=False,
                        help="The predict result file path")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help=
        "The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )
    parser.add_argument(
        "--init_checkpoint",
        default=None,
        type=str,
        help="Initial checkpoint (usually from a pre-trained BERT model).")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument(
        "--max_seq_length",
        default=300,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--num_labels",
                        default=1,
                        type=int,
                        help="mapping classify nums")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=6.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )

    args = parser.parse_args()
    vocab_path = os.path.join(args.bert_model, VOCAB_NAME)
    # bert_config = BertConfig.from_json_file(vocab_path)
    data_processor = DataProcessor()
    devices = tpu_xm.get_xla_supported_devices()
    n_tpu = len(devices)
    logging.info(f'Found {n_tpu} TPU cores')
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.do_train:
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
        else:
            os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path,
                                           do_lower_case=args.do_lower_case)
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          num_labels=3)
    for k, v in model.state_dict().items():
        print(f'k = {k}, v.grad = {v.grad}')

    model = tpu_dp.DataParallel(model, device_ids=devices)

    if args.do_train:
        # 数据读取
        train_examples = data_processor.get_examples(args.train_file,
                                                     data_type='train')
        eval_examples = data_processor.get_examples(args.eval_file,
                                                    data_type='eval')

        # 特征转换
        train_features = convert_examples_to_features(args, train_examples,
                                                      args.max_seq_length,
                                                      tokenizer)
        eval_features = convert_examples_to_features(args, eval_examples,
                                                     args.max_seq_length,
                                                     tokenizer)
        num_train_steps = int(
            len(train_features) // args.train_batch_size //
            args.gradient_accumulation_steps * args.num_train_epochs)

        # 数据loader
        train_loader = ParaDataloader(train_features)
        eval_loader = ParaDataloader(eval_features)

        # 数据并行loader输入格式
        train_loader = DataLoader(train_loader,
                                  shuffle=True,
                                  batch_size=args.train_batch_size)
        eval_loader = DataLoader(eval_loader,
                                 shuffle=False,
                                 batch_size=args.eval_batch_size)

    def tpu_training_loop(model, loader, device, context):
        """ Called by torch_xla_py.data_parallel. This function is executed on each core of the TPU once per epoch"""
        model.zero_grad()
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        param_optimizer = list(model.named_parameters())

        optimizer_grouped_parameters = [{
            'params': [p for n, p in param_optimizer if n not in no_decay],
            'weight_decay_rate':
            0.01
        }, {
            'params': [p for n, p in param_optimizer if n in no_decay],
            'weight_decay_rate':
            0.0
        }]
        optimizer = context.getattr_or(
            'optimizer',
            BertAdam(optimizer_grouped_parameters,
                     lr=args.learning_rate,
                     warmup=args.warmup_proportion,
                     t_total=num_train_steps))
        tr_loss = None
        pbar = None
        if str(pbar_device) == str(device):
            pbar = tqdm(total=int(pbar_steps),
                        desc=f"training",
                        dynamic_ncols=True)
        tracker = tpu_xm.RateTracker()
        model.train()
        for step, batch in enumerate(loader):
            input_ids, input_mask, segment_ids, label_ids = batch
            loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tracker.add(args.train_batch_size)
            tr_loss = loss * args.gradient_accumulation_steps if step == 0 else tr_loss + loss * args.gradient_accumulation_steps
            if pbar is not None:
                pbar.update(1)
            tpu_xm.optimizer_step(optimizer)
            # optimizer.step()
            optimizer.zero_grad()
        return tr_loss.item() / step

    def tpu_evaluating_loop(model, eval_dataloader, device, context):
        """Run one evaluation pass on a single TPU core.

        Returns a tuple of (average eval loss, predicted class ids,
        gold label ids) for this core's shard of the data.
        """
        model.eval()
        eval_loss = 0
        eval_pbar = None
        logits, labels = [], []
        # Only the designated device renders a progress bar so we do not get
        # one bar per TPU core.
        if str(pbar_device) == str(device):
            eval_pbar = tqdm(total=int(eval_pbar_steps),
                             desc=f"evaluating",
                             dynamic_ncols=True)
        tracker = tpu_xm.RateTracker()
        for step, batch in enumerate(eval_dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            with torch.no_grad():
                loss, logit = model(input_ids, segment_ids, input_mask,
                                    label_ids)
            # Undo the gradient-accumulation scaling so eval losses are
            # comparable to unscaled training losses.
            eval_loss = loss * args.gradient_accumulation_steps if step == 0 else eval_loss + loss * args.gradient_accumulation_steps
            logit = torch.argmax(logit, dim=-1)
            logits.extend(logit.tolist())
            labels.extend(label_ids.tolist())
            tracker.add(args.eval_batch_size)
            if eval_pbar is not None:
                eval_pbar.update(1)
        # Bug fix: `step` is the index of the LAST batch, so the number of
        # batches is step + 1; dividing by `step` skewed the average loss
        # (and divided by zero for a single-batch loader).
        return (eval_loss.item() / (step + 1), logits, labels)

    def tpu_predicting_loop(model, dataloader, device, context):
        """Generate predictions on a single TPU core.

        Returns (predicted class ids, example ids, per-class probabilities)
        for this core's shard of the data.
        """
        model.eval()
        logits, example_ids, probs = [], [], []
        # Only one core (pbar_device) shows a progress bar.
        eval_pbar = (tqdm(total=int(eval_pbar_steps),
                          desc=f"evaluating",
                          dynamic_ncols=True)
                     if str(pbar_device) == str(device) else None)
        tracker = tpu_xm.RateTracker()
        for batch_idx, batch in enumerate(dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            with torch.no_grad():
                logit = model(input_ids, segment_ids, input_mask)
            probs.extend(torch.softmax(logit, dim=-1).tolist())
            logits.extend(torch.argmax(logit, dim=-1).tolist())
            # In predict mode the "label" column actually carries example ids.
            example_ids.extend(label_ids.tolist())
            tracker.add(args.eval_batch_size)
            if eval_pbar is not None:
                eval_pbar.update(1)
        return logits, example_ids, probs

    def eval_meric(model, loop, data_loader):
        """Aggregate per-device evaluation results and log overall metrics.

        `model` is the dp.DataParallel wrapper: calling it runs `loop` on
        every device and returns one (loss, logits, labels) tuple per device.
        """
        eval_results = model(loop, data_loader)
        # Fixed: was `eval_loss, eval_loss = 0, 0` — a duplicated assignment
        # target; only a single accumulator is needed.
        eval_loss = 0
        all_logits, all_labels = [], []
        # NOTE(review): hard-codes an 8-core TPU topology — confirm.
        assert len(eval_results) == len(devices) == 8
        for eval_result in eval_results:
            eval_loss += eval_result[0]
            all_logits.extend(eval_result[1])
            all_labels.extend(eval_result[2])
        accuracy(all_labels, all_logits)
        logger.info(f'Average eval loss = {eval_loss / len(eval_results)}')

    def write_predict_file(model, loop, data_loader, file_path):
        """
        Write the prediction file. Output format follows '五彩滨云-final.csv'.
        """
        # `model` is the dp.DataParallel wrapper; this runs `loop` on every
        # device and collects one (logits, ids, probs) tuple per device.
        results = model(loop, data_loader)
        logits, ids, probs = [], [], []
        # NOTE(review): hard-codes an 8-core TPU topology — confirm.
        assert len(results) == len(devices) == 8
        for result in results:
            logits.extend(result[0])
            ids.extend(result[1])
            probs.extend(result[2])
        assert len(ids) == len(logits)
        logger.info(
            f'zero nums {logits.count(0)}, one nums {logits.count(1)}, two nums {logits.count(2)}'
        )
        # NOTE(review): looks up every id in eval_dict even in predict mode —
        # verify predict-set ids exist in data_processor.eval_dict.
        labels = [
            data_processor.eval_dict[id][1] for id, logit in zip(ids, logits)
        ]
        if not args.do_eval:
            # Pure prediction: shift predicted class ids back by one and dump
            # an id/label csv plus a second csv carrying class probabilities.
            logits = [i - 1 for i in logits]
            data_df = pd.DataFrame({'id': ids, 'y': logits})
            data_df1 = pd.DataFrame({'id': ids, 'y': logits, 'probs': probs})
            data_df1.to_csv('probs_predict.csv', index=None)
        else:
            # Evaluation mode: log accuracy and write a per-example
            # comparison file with predictions, gold labels and passages.
            assert len(labels) == len(logits)
            accuracy(labels, logits)
            passages = [
                data_processor.eval_dict[id][0]
                for id, logit in zip(ids, logits)
            ]
            assert len(labels) == len(passages)
            match_array = np.array((logits)) == np.array(labels)
            match_list = match_array.tolist()
            data_df = pd.DataFrame({
                'id': ids,
                'pred': logits,
                'real': labels,
                'probs': probs,
                'match': match_list,
                'passage': passages
            })
        data_df.to_csv(file_path, index=None)

    if args.do_train:
        for epoch in range(1, int(args.num_train_epochs) + 1, 1):
            # Only this device draws tqdm progress bars inside the loops.
            pbar_device = devices[0]
            logger.info(f'Start to evaluate......')
            # Per-core step counts: data is sharded across n_tpu cores.
            eval_pbar_steps = len(eval_loader) // n_tpu
            # NOTE(review): evaluation runs BEFORE each training epoch, so
            # the first eval measures the untrained model — confirm intended.
            eval_meric(model, tpu_evaluating_loop, eval_loader)
            pbar_steps = len(train_loader) // n_tpu
            logging.info(
                f'Start training, epoch {epoch} on {len(devices)} cores for {pbar_steps} steps'
            )
            start = time.time()
            # Runs tpu_training_loop on every core; returns one loss per core.
            losses = model(tpu_training_loop, train_loader)
            logging.info(
                f'Epoch {epoch} took {round(time.time() - start, 2)} seconds. average train loss: {sum(losses) / len(losses)}'
            )
            # Persist the weights of the first replica only; all replicas are
            # kept in sync by the synchronized optimizer step.
            save_checkpoint(model._models[0], epoch, args.output_dir)
            logger.info('Train finished......')

    elif args.do_predict:
        pbar_device = devices[0]
        logger.info(f'Start to predict......')
        # In eval mode, predict on the labeled eval set; otherwise on the
        # unlabeled predict set.
        if args.do_eval:
            predict_examples = data_processor.get_eval_examples(args.eval_file)
        else:
            predict_examples = data_processor.get_predict_examples(
                args.predict_file)

        predict_features = convert_examples_to_features(
            args, predict_examples, args.max_seq_length, tokenizer)
        predict_loader = ParaDataloader(predict_features)
        predict_loader = DataLoader(predict_loader,
                                    shuffle=False,
                                    batch_size=args.eval_batch_size)
        eval_pbar_steps = len(predict_loader) // n_tpu
        write_predict_file(model, tpu_predicting_loop, predict_loader,
                           args.predict_result_file)
예제 #13
0
def train_mnist():
  """Train MNIST across XLA devices with dp.DataParallel.

  Returns the final test accuracy as a percentage.
  """
  torch.manual_seed(1)

  def _fake_batch():
    # One zeroed (images, labels) pair shaped like a real MNIST batch.
    return (torch.zeros(FLAGS.batch_size, 1, 28, 28),
            torch.zeros(FLAGS.batch_size, dtype=torch.int64))

  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=_fake_batch(), sample_count=60000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=_fake_batch(), sample_count=10000 // FLAGS.batch_size)
  else:
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            FLAGS.datadir,
            train=True,
            download=True,
            transform=mnist_transform),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            FLAGS.datadir, train=False, transform=mnist_transform),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)

  devices = []
  if FLAGS.num_cores != 0:
    devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
  # Scale learning rate linearly with the number of cores.
  lr = FLAGS.lr * max(len(devices), 1)
  # An empty device_ids list makes DataParallel fall back to PyTorch/CPU.
  model_parallel = dp.DataParallel(MNIST, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    # Cache the optimizer on the per-device context so its state persists
    # across epochs.
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum))
    tracker = xm.RateTracker()

    model.train()
    for step, (data, target) in loader:
      optimizer.zero_grad()
      loss = loss_fn(model(data), target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if step % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
            device, step, loss.item(), tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    correct, total_samples = 0, 0
    model.eval()
    for _, (data, target) in loader:
      pred = model(data).max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy * 100.0
예제 #14
0
                 update_freq=[1],
                 upsample_primary=16,
                 user_dir=None,
                 valid_subset='valid',
                 validate_interval=1,
                 warmup_init_lr=1e-07,
                 warmup_updates=4000,
                 weight_decay=0.0)

task = tasks.setup_task(args)
task.load_dataset(args.train_subset, combine=True, epoch=0)
for valid_sub_split in args.valid_subset.split(','):
    task.load_dataset(valid_sub_split, combine=True, epoch=0)

#devices = xm.get_xla_supported_devices(max_devices=8) # Got error for max devices argument :(
devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(lambda: task.build_model(args),
                                 device_ids=devices)

#max_positions = utils.resolve_max_positions(
#                task.max_positions(),
#                        model.max_positions(),
#
#        )
max_positions = (1024, 1024
                 )  # Hardcoded for the moment since the computation requires
# model object which will be created by model_parallel __call__
# Re-factor in a cleaner way

# Initialize dataloader
epoch_itr = task.get_batch_iterator(
예제 #15
0
def train_imagenet():
  """Train an ImageNet model across XLA devices; return accuracy in percent.

  Fixes vs the original:
    * the per-epoch accuracy averaged over ``len(devices)``, which raises
      ZeroDivisionError on the CPU fallback (``devices == []``) — it now
      averages over the number of returned results;
    * the loops now toggle ``model.train()`` / ``model.eval()`` so
      dropout/batch-norm behave correctly in each phase (consistent with
      train_mnist in this file).
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    # Synthetic zeroed tensors shaped like real ImageNet batches.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=1200000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size)
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=5e-4)
    tracker = xm.RateTracker()
    model.train()  # ensure training mode after the eval phase
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
                                                        tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()  # disable dropout / use running batch-norm statistics
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    # Average over the results actually returned; dividing by len(devices)
    # crashed when devices == [] (CPU engine).
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy * 100.0
예제 #16
0
파일: train_tpu.py 프로젝트: tweetyone/DCL
            raise Exception("no checkpoints to load")

        model_dict = model.state_dict()
        pretrained_dict = torch.load(resume)
        pretrained_dict = {
            k[7:]: v
            for k, v in pretrained_dict.items() if k[7:] in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    print('Set cache dir', flush=True)
    time = datetime.datetime.now()

    num_cores = 8
    devices = (xm.get_xla_supported_devices(
        max_devices=num_cores) if num_cores != 0 else [])
    # Scale learning rate to num cores
    base_lr = args.base_lr * max(len(devices), 1)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(model, device_ids=devices)

    # optimizer prepare
    ignored_params1 = list(map(id, model.classifier.parameters()))
    ignored_params2 = list(map(id, model.classifier_swap.parameters()))
    ignored_params3 = list(map(id, model.Convmask.parameters()))

    ignored_params = ignored_params1 + ignored_params2 + ignored_params3
    print('the num of new layers:', len(ignored_params), flush=True)
    base_params = filter(lambda p: id(p) not in ignored_params,
                         model.parameters())
예제 #17
0
def train_imagenet():
    """Train an ImageNet model across XLA replicas; return final accuracy.

    Fix vs the original: the TensorBoard ``SummaryWriter`` was never closed,
    leaking the file handle and risking unflushed events — it is now closed
    before returning.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        # Synthetic data sharded by world size so each replica sees its slice.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=1200000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            # Distributed samplers shard the datasets across replicas.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    torchvision_model = get_model_property('model_fn')
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        # Cache the optimizer on the per-device context so its momentum
        # buffers survive across epochs.
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    # Flush and close the TensorBoard writer (was leaked in the original).
    if writer is not None:
        writer.close()
    return accuracy
예제 #18
0
def train_cifar():
    """Train ResNet18 on CIFAR-10 across XLA devices; return accuracy in %.

    Fixes vs the original:
      * the per-epoch accuracy averaged over ``len(devices)``, which raises
        ZeroDivisionError on the CPU fallback (``devices == []``) — it now
        averages over the number of returned results;
      * the loops now toggle ``model.train()`` / ``model.eval()`` so
        batch-norm behaves correctly in each phase.
    """
    print('==> Preparing data..')

    if FLAGS.fake_data:
        # Synthetic zeroed tensors shaped like real CIFAR-10 batches.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size)
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)

        testset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=FLAGS.batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(ResNet18, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(),
                              lr=FLAGS.lr,
                              momentum=FLAGS.momentum,
                              weight_decay=5e-4)
        tracker = xm.RateTracker()

        model.train()  # ensure training mode after the eval phase
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, x, loss.item(), tracker.rate()))

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()  # use running batch-norm statistics during testing
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        print('[{}] Accuracy={:.2f}%'.format(device,
                                             100.0 * correct / total_samples))
        return correct / total_samples

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        # Average over the results actually returned; dividing by
        # len(devices) crashed when devices == [] (CPU engine).
        accuracy = sum(accuracies) / len(accuracies)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy * 100.0
예제 #19
0
def main_tpu(args):
  """Train and validate a fairseq model on TPU cores via dp.DataParallel."""

  def log_step(step_type, device, step, tracker=None, metrics_debug=False):
    # Format (not print) a progress line; the caller decides where it goes.
    # NOTE(review): metrics_debug is accepted but never used here.
    msg = '{}/ {}, device {}, step {}'.format(step_type, utils.now(), device,
                                              step)
    if tracker:
      rates = tracker.rate(), tracker.global_rate()
      msg += ', Rate={:.2f}, Global Rate={:.2f}'.format(*rates)
    return msg

  def train_loop_fn(model, loader, device, context):
    # One fairseq Trainer per device, keyed by the device string.
    trainer = trainers[str(device)]
    stats = None
    tracker = xm.RateTracker()
    for i, samples in loader:
      if i and not (i % args.log_steps):
        print(
            log_step(
                'training',
                device,
                i,
                tracker=tracker,
                metrics_debug=args.metrics_debug))
      _log_output = trainer.train_step(samples)
      xm.optimizer_step(trainer.optimizer)
      tracker.add(len(samples) * args.max_sentences)  # n_batches * batch_size
    stats = fairseq_train.get_training_stats(trainer)
    return tracker, stats

  def valid_loop_fn(model, loader, device, context):
    trainer = trainers[str(device)]
    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
      meter = trainer.get_meter(k)
      if meter is not None:
        meter.reset()
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in loader:
      if not (i % args.log_steps):
        print(
            log_step(
                'validation',
                device,
                i,
                tracker=None,
                metrics_debug=args.metrics_debug))
      log_output = trainer.valid_step(sample)
      for k, v in log_output.items():
        # These keys are already aggregated by the trainer's own meters.
        if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
          continue
        extra_meters[k].update(v)
    stats = fairseq_train.get_valid_stats(trainer)
    for k, meter in extra_meters.items():
      stats[k] = meter.avg
    return stats

  def validate_subset(args, trainers, task, epoch_itr, subset):
    print('Validating the subset "{}"'.format(subset))
    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            list(trainers.values())[0].get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_workers=args.num_workers).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        prefix='valid on \'{}\' subset'.format(subset),
        no_progress_bar='simple')
    stats_per_device = model_parallel(valid_loop_fn, progress)
    valid_losses = [stats['loss'].avg for stats in stats_per_device]
    print('validation stats on subset "{}" - {}'.format(subset, utils.now()))
    for stats in stats_per_device:
      # NOTE(review): `trainer` is not defined in this scope — this looks
      # like a NameError; probably list(trainers.values())[0] was intended.
      progress.print(stats, tag=subset, step=trainer.get_num_updates())
    return valid_losses

  def validate(args, trainers, task, epoch_itr, subsets):
    # Run validation on every requested subset; map subset -> per-device losses.
    valid_losses = {
        subset: validate_subset(args, trainers, task, epoch_itr, subset)
        for subset in subsets
    }
    return valid_losses

  def initialize_loader_for_epoch(args, epoch_itr):
    # Pick this epoch's gradient-accumulation frequency from the schedule.
    if epoch_itr.epoch <= len(args.update_freq):
      update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
      update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=False, shuffle=(epoch_itr.epoch >= args.curriculum))
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple')
    return progress

  def keep_training(lr, epoch_itr, trainers):
    # Train until the learning rate gets too small
    # NOTE(review): the `lr` parameter is immediately overwritten below, and
    # the threshold reads FLAGS.min_lr while everything else uses `args` —
    # confirm this mix is intentional.
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = min(trainer.get_lr() for trainer in trainers.values())
    n_updates = max(trainer.get_num_updates() for trainer in trainers.values())
    return ((lr > FLAGS.min_lr) and (epoch_itr.epoch < max_epoch) and
            (n_updates < max_update))

  # Dump the full argument namespace for reproducibility.
  xu.eprint('Args')
  for key, val in args.__dict__.items():
    xu.eprint('\t{} {}'.format(key, val))
  xu.eprint('---------')

  devices = xm.get_xla_supported_devices(max_devices=args.num_cores)
  task, trainers, model_parallel, epoch_itr, lr, valid_subsets = prepare_task(
      args, devices)

  train_meter = StopwatchMeter()
  train_meter.start()
  while keep_training(lr, epoch_itr, trainers):
    # TRAINING
    print('Epoch {} begin {}'.format(epoch_itr.epoch + 1, utils.now()))
    progress = initialize_loader_for_epoch(args, epoch_itr)
    out = model_parallel(train_loop_fn, progress)
    trackers, stats_ = zip(*out)
    print('Epoch {} Training stats:'.format(epoch_itr.epoch))
    for device, trainer in trainers.items():
      stats = fairseq_train.get_training_stats(trainer)
      print('device {}'.format(device))
      progress.print(stats, tag=device)
    print('Epoch {} Tracker Rates:'.format(epoch_itr.epoch))
    for tracker in trackers:
      rates = tracker.rate(), tracker.global_rate()
      print('\tRate={:.2f}, Global Rate={:.2f}'.format(*rates))
    print('Epoch {} end {}'.format(epoch_itr.epoch, utils.now()))
    if args.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

    # VALIDATION
    if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
      valid_losses = validate(args, trainers, task, epoch_itr, valid_subsets)

      # only use average first validation loss from the first device
      # to update the learning rate
      vloss = valid_losses[valid_subsets[0]][0]
      print('old learning rate: {}'.format(lr))
      lr = trainers[devices[0]].lr_step(epoch_itr.epoch, vloss)
      print('new learning rate: {}'.format(lr))

      # save checkpoint
      if epoch_itr.epoch % args.save_interval == 0:
        # NOTE(review): `trainer` here is the leftover loop variable from the
        # training-stats loop above (last device) — probably
        # trainers[devices[0]] was intended; confirm.
        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, vloss)

    if args.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  train_meter.stop()
  print('| done training in {:.1f} seconds'.format(train_meter.sum))
예제 #20
0
 def test_get_real_xla_devices(self):
     """Every supported device must map to a real CPU/GPU/TPU device id."""
     supported = xm.get_xla_supported_devices()
     real = torch_xla._XLAC._xla_real_devices(supported)
     for dev, real_dev in zip(supported, real):
         self.assertTrue(re.match(r'(CPU|GPU|TPU):\d+$', real_dev) is not None)
예제 #21
0
def train_cifar():
    """Train ResNet18 on CIFAR10 across the available XLA devices.

    Builds CIFAR10 data loaders (synthetic when FLAGS.fake_data is set),
    replicates the model with dp.DataParallel, and alternates a training
    and an evaluation pass for FLAGS.num_epochs epochs.

    Returns:
        The mean test accuracy (percent) across devices after the last epoch.
    """
    print('==> Preparing data..')

    if FLAGS.fake_data:

        def fake_loader(total_samples):
            # Zero-filled batches; sample count is split across replicas.
            return xu.SampleGenerator(
                data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                      torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
                sample_count=total_samples // FLAGS.batch_size //
                xm.xrt_world_size())

        # 50000/10000 mirror the CIFAR10 train/test split sizes.
        train_loader = fake_loader(50000)
        test_loader = fake_loader(10000)
    else:
        # Standard CIFAR10 per-channel normalization statistics.
        cifar_mean = (0.4914, 0.4822, 0.4465)
        cifar_std = (0.2023, 0.1994, 0.2010)
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ])

        train_dataset = torchvision.datasets.CIFAR10(
            root=FLAGS.datadir,
            train=True,
            download=True,
            transform=train_transform)
        test_dataset = torchvision.datasets.CIFAR10(
            root=FLAGS.datadir,
            train=False,
            download=True,
            transform=test_transform)

        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            # Shard both datasets so each replica sees a distinct slice.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            # DataLoader forbids shuffle=True together with a sampler.
            shuffle=not train_sampler,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    network = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
    model_parallel = dp.DataParallel(network, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # One SGD optimizer per device, cached on the per-device context so
        # its state survives across epochs.
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        loss_fn = nn.CrossEntropyLoss()
        tracker = xm.RateTracker()

        model.train()
        for step, (data, target) in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, step, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        model.eval()
        correct = 0
        total_samples = 0
        for step, (data, target) in loader:
            # Index of the max logit per row is the predicted class.
            pred = model(data).max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        # One accuracy per device; report their mean.
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = sum(accuracies) / len(accuracies)
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy