Example #1
    def map_fn(self, index, train_dataset, dev_dataset, lr, epochs, batch_size, callbacks):
        if self.using_tpu:
            device = xm.xla_device()
        else:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        train_loader = self.make_loader(train_dataset, batch_size, 'train')
        dev_loader = self.make_loader(dev_dataset, batch_size, 'dev')

        model = self.model.to(device)
        if self.using_tpu:
            opt = self.Opt([param for param in model.parameters() if param.requires_grad],
                           lr=lr*xm.xrt_world_size(), weight_decay=1e-4)  # weight decay hard-coded
        else:
            opt = self.Opt([param for param in model.parameters() if param.requires_grad],
                           lr=lr, weight_decay=1e-4)  # weight decay hard-coded

        loss_fn = self.Loss_fn(from_logits=True)

        callback_kwargs = {
            "model": model,
            "eval_dic": self.dev_eval,
        }

        for callback in callbacks:
            callback.train_init(**callback_kwargs)

        for epoch in range(epochs):
            if self.using_tpu:
                xm.rendezvous("training is starting!")
                if xm.is_master_ordinal():
                    print(f"\nepoch : {epoch+1} / {epochs}")
                now_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
            else:
                print(f"epoch : {epoch+1} / {epochs}")
                now_train_loader = train_loader
            model.train()
            for step, batch in enumerate(now_train_loader):
                logits, y, loss = self.compute_batch(model, batch, device, loss_fn, opt, phase='train')

                if self.using_tpu:
                    xm.rendezvous("update is starting!")
                    self.update(logits, y, loss, 'train', batch_size)
                    xm.rendezvous("update is ended!")
                    if xm.is_master_ordinal():
                        self.show_log(step*xm.xrt_world_size(), train_dataset, batch_size, 'train')
                else:
                    self.update(logits, y, loss, 'train', batch_size)
                    self.show_log(step, train_dataset, batch_size, 'train')

            if self.using_tpu:
                xm.rendezvous("batch is done!")
                if xm.is_master_ordinal():
                    print()
            else:
                print()

            model.eval()
            with torch.no_grad():
                if self.using_tpu:
                    now_dev_loader = pl.ParallelLoader(dev_loader, [device]).per_device_loader(device)
                else:
                    now_dev_loader = dev_loader
                for step, batch in enumerate(now_dev_loader):
                    logits, y, loss = self.compute_batch(model, batch, device, loss_fn, opt, phase='dev')

                    if self.using_tpu:
                        xm.rendezvous("update is starting!")
                        self.update(logits, y, loss, 'dev', batch_size)
                        xm.rendezvous("eval update is ended!")
                        if xm.is_master_ordinal():
                            self.show_log(step*xm.xrt_world_size(), dev_dataset, batch_size, 'dev')
                    else:
                        self.update(logits, y, loss, 'dev', batch_size)
                        self.show_log(step, dev_dataset, batch_size, 'dev')

                if self.using_tpu:
                    xm.rendezvous("batch is done!")
                    if xm.is_master_ordinal():
                        print()
                else:
                    print()
            self.on_epoch_end(callbacks)

        if self.using_tpu:
            xm.rendezvous("training is over!")
Example #2
def train(rank, args):
    print('enter train @ %s'%(rank), flush=True)
    args.rank = rank
    args.split = ''
    torch.manual_seed(42)
    save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size() if not args.vocab_size else args.vocab_size
    
    train_dataset = get_dataset(args)
    
    batched_already = hasattr(train_dataset, '__getbatch__')

    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates

    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)
    
    train_sampler = None
    if args.gpus:
        dist.init_process_group(
            'nccl', 
            rank=rank, 
            world_size=args.world_size
        )
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=args.shuffle)

    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=args.shuffle)


    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size if not batched_already else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)
        

    eval_loaders = []
    if args.eval_dir:
        for split in args.splits.split(','):
            split = split.strip()
            eval_sampler = None
            args.split = split
            eval_dataset = get_eval_dataset(args)
            # Build the distributed sampler over the eval dataset itself.
            if args.gpus:
                if args.gpus > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset,
                        num_replicas=args.gpus,
                        rank=rank,
                        shuffle=False)

            else:
                rank = xm.get_ordinal()
                if xm.xrt_world_size() > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset,
                        num_replicas=xm.xrt_world_size(),
                        rank=rank,
                        shuffle=False)

            eval_loader = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=args.batch_size if not batched_already else None,
                sampler=eval_sampler,
                pin_memory=True,
                shuffle=False,
                num_workers=args.num_workers)
            eval_loaders.append(eval_loader)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)


        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args, tokenizer)

        model.cuda(rank)

        device = torch.device('cuda:'+str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay,
                                 special_layer_wise_lr=args.special_layer_wise_lr,
                                 log = rank == 0,
                                 ),  

                                 # use this function to set extra optimizer arguments, 
                                 # see model_get_parameters
            betas=(0.9, 0.999), 
            eps=1e-6,
            lr=args.lr, 
            weight_decay=args.weight_decay
        )




        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader

    else:
        assert tpu_enabled
        device = xm.xla_device()


        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        
        model = get_model(args, tokenizer)


        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device)
        ##  So we first find the shared parameters first
        ##
        ##########################

        shared_parameters = {e[0]: e[1:] for e in _catalog_shared_params(model)}

        model.to(device)
        
        do_share_parameters_again(model, shared_parameters, log = rank == 0)



        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay
                                 ),  

                                 # use this function to set extra optimizer arguments, 
                                 # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay
        )


        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)
                
        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

        # tracker = xm.RateTracker()
        
        
        
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
                
        if args.gpus:
            states = {"module.%s"%k : v for k, v in states.items()}
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            if rank == 0:
                traceback.print_exc()
            model.load_state_dict(states, strict=False)
            
        
    if rank == 0:
        if not os.path.exists(os.path.dirname(save_fn)):
            try:
                os.makedirs(os.path.dirname(save_fn))
            except OSError as exc: # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise
        if args.gpus:
            torch.save(model.state_dict(), save_fn )
        else:
            xm.save(model.state_dict(), save_fn )
        
    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################
    
    if not batched_already:
        args.total_num_updates = args.total_num_updates // args.batch_size
        args.warmup_updates = args.warmup_updates // args.batch_size

    args.total_num_updates = args.total_num_updates // args.world_size
    args.warmup_updates = args.warmup_updates // args.world_size

    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=args.warmup_updates, 
        num_training_steps=args.total_num_updates, 
    )

    step_i = 0

    err = None
    tb = None
    #tb = SummaryWriter()
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates, file=sys.stdout)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
                
            n_samples = len(batches)
                
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break

                report_step = step_i % args.log_interval == 0

                while True:  # replay loop; only needed for apex gradient-overflow handling
                    optimizer.zero_grad()

                    total_loss, log = get_loss(
                        model, 
                        sample, 
                        args=args, 
                        device=device, 
                        gpus=args.gpus, 
                        report=report_step
                    )

                    if args.gpus:
                        default_optimizer_step = optimizer.step

                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()

                        # If Amp detects an overflow, it patches optimizer.step.  In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break

                        optimizer.step() # If an overflow was detected, "optimizer.step" is the patched call, which does 
                                         # nothing but restore optimizer.step to default_optimizer_step.
                        if rank == 0:
                            print("Overflowed, reducing loss scale and replaying batch.", flush=True)
                        
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()

                        break



                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss

                    # tb.add_scalar("Loss", total_loss, step_i)

                    for k, v in log.items():
                        try:
                            dist.all_reduce(v, op=dist.ReduceOp.SUM)
                            log[k] = float(v)
                        except Exception as e:
                            print(v, e)
                            pass
                        
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter, tb, step_i))
                    else:
                        xm.add_step_closure(_train_update, args=(log, log_formatter, tb, step_i))

                    if args.report_metrics:
                        xm.master_print(met.metrics_report())

                
                if rank == 0:
                    pbar.update(1)

        if rank == 0:
            pbar.close()
        if eval_loaders:
            model.half()
            model.eval()
            model.to(device)
            for k, v in model.named_parameters():
                v.requires_grad = False

                
            for split, eval_loader in zip(args.splits.split(','), eval_loaders):
                batches = eval_loader
                if rank == 0:
                    eval_length = len(batches)
                    if not batched_already:
                        eval_length = eval_length // args.batch_size

                    eval_length = eval_length // args.world_size

                    pbar = tqdm(total=eval_length, file=sys.stdout)
                
                if not args.gpus:
                    batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device)
                with torch.no_grad():
                    record = OrderedDict()

                    for sample in batches:
                        evaluate(
                            model, 
                            sample, 
                            args=args, 
                            device=device, 
                            record=record,
                            gpus=args.gpus, 
                            report=False
                        )
                        if rank == 0:
                            pbar.update(1)

                    for k, v in record.items():
                        try:
                            def handle_reduce(v):
                                if len(v.shape) == 0:
                                    dist.all_reduce(v, op=dist.ReduceOp.SUM)
                                else:
                                    L = [torch.ones_like(v) for _ in range(dist.get_world_size())]
                                    dist.all_gather(L, v)
                                    v = torch.cat(L, dim=0)
                                return v
                            if isinstance(v, list):
                                v = [handle_reduce(e) for e in v]
                            else:
                                v = handle_reduce(v)
                            record[k] = float(v)
                        except Exception as e:
                            pass

                    post_evaluate(record, args=args)

                import json

                if rank == 0:
                    print('',flush=True)
                    print('Test result for %s'%split, flush=True)
                    print(json.dumps(record, indent=2),flush=True)
                    print('',flush=True)


    except Exception as _err:
        err = _err
    finally:
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0:
            print("Saving to %s"%save_fn)
            if args.gpus:
                torch.save(model.state_dict(), save_fn )
                if err:
                    raise err
            else:
                xm.save(model.state_dict(), save_fn )
                if err:
                    raise err
            print("Saved to %s"%save_fn)
Example #3
def train_mnist(flags, **kwargs):
    torch.manual_seed(1)

    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=60000 // flags.batch_size // xm.xrt_world_size(),
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(flags.batch_size, 1, 28, 28),
                torch.zeros(flags.batch_size, dtype=torch.int64),
            ),
            sample_count=10000 // flags.batch_size // xm.xrt_world_size(),
        )
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        )
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        )
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers,
        )

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, writer))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print("Epoch {} train end {}".format(epoch,
                                                       test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={"Accuracy/test": accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy
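train_mnist above registers a _train_update step closure that is not defined in this example; a minimal sketch of what it might look like follows, reusing the test_utils helpers already called elsewhere in these examples (the exact original signature may differ).
def _train_update(device, step, loss, tracker, writer):
    # Sketch only: print a progress line on the host once the step's tensors
    # are materialized; writing the loss to the summary writer is an assumption.
    test_utils.print_training_update(device, step, loss.item(),
                                     tracker.rate(), tracker.global_rate())
    if writer is not None:
        test_utils.write_to_summary(writer, step,
                                    dict_to_write={'Loss/train': loss.item()})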
Example #4
def train_imagenet():
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  device = xm.xla_device()
  model = get_model_property('model_fn')().to(device)
  writer = None
  if FLAGS.logdir and xm.is_master_ordinal():
    writer = SummaryWriter(log_dir=FLAGS.logdir)
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
      optimizer,
      scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
      scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
      scheduler_divide_every_n_epochs=getattr(
          FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
      num_steps_per_epoch=num_training_steps_per_epoch,
      summary_writer=writer)
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if lr_scheduler:
        lr_scheduler.step()
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    if xm.is_master_ordinal():
      print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy = test_loop_fn(para_loader.per_device_loader(device))
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)

    if FLAGS.metrics_debug:
      print(met.metrics_report())

  return accuracy
Example #5
def is_master_ordinal(self):
    return xm.is_master_ordinal()
Example #6
def train_imagenet():
    torch.manual_seed(42)

    device = xm.xla_device()
    # model = get_model_property('model_fn')().to(device)
    model = create_model(
        FLAGS.model,
        pretrained=FLAGS.pretrained,
        num_classes=FLAGS.num_classes,
        drop_rate=FLAGS.drop,
        global_pool=FLAGS.gp,
        bn_tf=FLAGS.bn_tf,
        bn_momentum=FLAGS.bn_momentum,
        bn_eps=FLAGS.bn_eps,
        drop_connect_rate=0.2,
        checkpoint_path=FLAGS.initial_checkpoint,
        args = FLAGS).to(device)
    model_ema=None
    if FLAGS.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        # import pdb; pdb.set_trace()
        model_e = create_model(
            FLAGS.model,
            pretrained=FLAGS.pretrained,
            num_classes=FLAGS.num_classes,
            drop_rate=FLAGS.drop,
            global_pool=FLAGS.gp,
            bn_tf=FLAGS.bn_tf,
            bn_momentum=FLAGS.bn_momentum,
            bn_eps=FLAGS.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=FLAGS.initial_checkpoint,
            args = FLAGS).to(device)
        model_ema = ModelEma(
            model_e,
            decay=FLAGS.model_ema_decay,
            device='cpu' if FLAGS.model_ema_force_cpu else '',
            resume=FLAGS.resume)
    print('==> Preparing data..')
    img_dim = 224
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                    torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                    torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    # else:
    #     normalize = transforms.Normalize(
    #         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    #     train_dataset = torchvision.datasets.ImageFolder(
    #         os.path.join(FLAGS.data, 'train'),
    #         transforms.Compose([
    #             transforms.RandomResizedCrop(img_dim),
    #             transforms.RandomHorizontalFlip(),
    #             transforms.ToTensor(),
    #             normalize,
    #         ]))
    #     train_dataset_len = len(train_dataset.imgs)
    #     resize_dim = max(img_dim, 256)
    #     test_dataset = torchvision.datasets.ImageFolder(
    #         os.path.join(FLAGS.data, 'val'),
    #         # Matches Torchvision's eval transforms except Torchvision uses size
    #         # 256 resize for all models both here and in the train loader. Their
    #         # version crashes during training on 299x299 images, e.g. inception.
    #         transforms.Compose([
    #             transforms.Resize(resize_dim),
    #             transforms.CenterCrop(img_dim),
    #             transforms.ToTensor(),
    #             normalize,
    #         ]))

    #     train_sampler = None
    #     if xm.xrt_world_size() > 1:
    #         train_sampler = torch.utils.data.distributed.DistributedSampler(
    #             train_dataset,
    #             num_replicas=xm.xrt_world_size(),
    #             rank=xm.get_ordinal(),
    #             shuffle=True)
    #     train_loader = torch.utils.data.DataLoader(
    #         train_dataset,
    #         batch_size=FLAGS.batch_size,
    #         sampler=train_sampler,
    #         shuffle=False if train_sampler else True,
    #         num_workers=FLAGS.workers)
    #     test_loader = torch.utils.data.DataLoader(
    #         test_dataset,
    #         batch_size=FLAGS.batch_size,
    #         shuffle=False,
    #         num_workers=FLAGS.workers)
    else:
        train_dir = os.path.join(FLAGS.data, 'train')
        data_config = resolve_data_config(model, FLAGS, verbose=FLAGS.local_rank == 0)
        dataset_train = Dataset(train_dir)

        collate_fn = None
        if not FLAGS.no_prefetcher and FLAGS.mixup > 0:
            collate_fn = FastCollateMixup(FLAGS.mixup, FLAGS.smoothing, FLAGS.num_classes)
        train_loader = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=True,
            use_prefetcher=not FLAGS.no_prefetcher,
            rand_erase_prob=FLAGS.reprob,
            rand_erase_mode=FLAGS.remode,
            interpolation='bicubic',  # FIXME cleanly resolve this? data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
            collate_fn=collate_fn,
            use_auto_aug=FLAGS.auto_augment,
            use_mixcut=FLAGS.mixcut,
        )

        eval_dir = os.path.join(FLAGS.data, 'val')
        train_dataset_len = len(train_loader)
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
        dataset_eval = Dataset(eval_dir)

        test_loader = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size = FLAGS.batch_size,
            is_training=False,
            use_prefetcher=FLAGS.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
        )


    writer = None
    start_epoch = 0
    if FLAGS.output and xm.is_master_ordinal():
        writer = SummaryWriter(log_dir=FLAGS.output)
    optimizer = create_optimizer(FLAGS, model)
    lr_scheduler, num_epochs = create_scheduler(FLAGS, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    # optimizer = optim.SGD(
    #     model.parameters(),
    #     lr=FLAGS.lr,
    #     momentum=FLAGS.momentum,
    #     weight_decay=5e-4)
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
        
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=FLAGS.smoothing)
    validate_loss_fn = nn.CrossEntropyLoss()
    # loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = train_loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if model_ema is not None:
                model_ema.update(model)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                            tracker.global_rate())

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy
    def test_loop_fn_ema(loader):
        total_samples = 0
        correct = 0
        model_ema.eval()
        for x, (data, target) in loader:
            output = model_ema(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy
    accuracy = 0.0
    for epoch in range(1, FLAGS.epochs + 1):
        para_loader = dp.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))

        para_loader = dp.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        if model_ema is not None:
            para_loader = dp.ParallelLoader(test_loader, [device])
            accuracy = test_loop_fn_ema(para_loader.per_device_loader(device))
            print('Epoch: {}, Mean EMA Accuracy: {:.2f}%'.format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)

        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy
Example #7
def train(args, train_loader, model, device, optimizer, scheduler, epoch, f,
          max_seq_len):
    total_loss = AverageMeter()
    losses1 = AverageMeter()  # start
    losses2 = AverageMeter()  # end
    accuracies1 = AverageMeter()  # start
    accuracies2 = AverageMeter()  # end

    model.train()

    tr_loss = 0.0

    t = tqdm(train_loader, disable=not xm.is_master_ordinal())
    for step, d in enumerate(t):

        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        token_type_ids = d["token_type_ids"].to(device)
        start_position = d["start_position"].to(device)
        end_position = d["end_position"].to(device)

        sentiment_label = d["sentiment_label"].to(device)

        model.zero_grad()

        logits1, logits2 = model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=None,
                                 head_mask=None)

        #y_true = (start_position, end_position)
        loss1, loss2 = loss_fn((logits1, logits2),
                               (start_position, end_position))
        #loss3 = loss_fn_sentiment(logits3, sentiment_label)
        loss = loss1 + loss2

        #max_seq_len = 256
        #loss = Closs.loss_fn(logits1, logits2, start_position, end_position,device, max_seq_len)

        acc1, n_position1 = get_position_accuracy(logits1, start_position)
        acc2, n_position2 = get_position_accuracy(logits2, end_position)

        total_loss.update(loss.item(), n_position1)
        losses1.update(loss1.item(), n_position1)
        losses2.update(loss2.item(), n_position2)
        accuracies1.update(acc1, n_position1)
        accuracies2.update(acc2, n_position2)

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        loss.backward()

        tr_loss += loss.item()
        if (step + 1) % args.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)

            xm.optimizer_step(optimizer)
            scheduler.step()
            model.zero_grad()

        print_loss = xm.mesh_reduce("loss_reduce", total_loss.avg, reduce_fn)
        print_acc1 = xm.mesh_reduce("acc1_reduce", accuracies1.avg, reduce_fn)
        print_acc2 = xm.mesh_reduce("acc2_reduce", accuracies2.avg, reduce_fn)
        t.set_description(
            f"Train E:{epoch+1} - Loss:{print_loss:0.2f} - acc1:{print_acc1:0.2f} - acc2:{print_acc2:0.2f}"
        )

    log_ = f"Epoch : {epoch+1} - train_loss : {total_loss.avg} - \n \
    train_loss1 : {losses1.avg} - train_loss2 : {losses2.avg} - \n \
    train_acc1 : {accuracies1.avg} - train_acc2 : {accuracies2.avg}"

    f.write(log_ + "\n\n")
    f.flush()

    return total_loss.avg
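The train() function above relies on AverageMeter and reduce_fn, whose definitions are not included in this example; a minimal sketch of typical implementations follows (assumptions, not the original code).
import numpy as np

class AverageMeter:
    # Tracks a running average of a scalar such as a loss or an accuracy.
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)

def reduce_fn(values):
    # xm.mesh_reduce hands reduce_fn a list with one value per TPU core.
    return np.mean(values)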
Example #8
def train_mnist():
    torch.manual_seed(1)

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(os.path.join(FLAGS.datadir,
                                                    str(xm.get_ordinal())),
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        test_dataset = datasets.MNIST(os.path.join(FLAGS.datadir,
                                                   str(xm.get_ordinal())),
                                      train=False,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    # Scale learning rate to num cores
    lr = FLAGS.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, x, loss, tracker))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print('Finished training epoch {}'.format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
Example #9
def multi_core(index, flags):
    torch.manual_seed(flags['seed'])
    batch_size = 4
    device = xm.xla_device()
    max_epoch = 1

    # Only download the datasets on one process: non-master ordinals wait here
    # until the master has created (and downloaded) them.
    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    train_dataset = MRIDataset(mode='train')
    valid_dataset = MRIDataset(mode='validation')
    val_loss = []

    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    # XLA distributed sampler for more than 1 TPU
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=flags['num_workers'],
        drop_last=True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=flags['batch_size'],
        sampler=valid_sampler,
        num_workers=flags['num_workers'],
        drop_last=True)
    model = NeuroImageModel().to(device).train()

    criterion = torch.nn.L1Loss(reduction='mean')
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.hyper_params['lr'], betas=config.hyper_params['betas'], eps=1e-08)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.hyper_params['lr'],
                                momentum=0.9,
                                nesterov=True)
    train_start = time.time()
    file = open('loss', 'w')
    average = 0
    for epoch in range(flags['num_epochs']):
        # Training time
        train_pl_loader = pl.ParallelLoader(train_loader,
                                            [device]).per_device_loader(device)
        start = time.time()
        average = 0
        count = 0
        for batch_num, batch in enumerate(train_pl_loader):
            optimizer.zero_grad()
            print("Process", index, "saving scan")
            scans = batch['scans']
            data = batch['data']
            targets = batch['targets']
            output = model(data, scans)
            del scans
            loss = criterion(output, targets)
            del targets
            loss.backward()
            xm.master_print(
                f'training: index: {batch_num} loss: {loss.item()}')
            count = count + 1
            average = average + loss.item()
            xm.optimizer_step(optimizer, barrier=True)
        print(
            f'Training loss for epoch: {epoch} average of: {average/count} with count {count}'
        )
        file.write(
            f'Training loss for epoch: {epoch} average of: {average/count} with count {count}'
        )
        # average = 0
        # count = 0
        del loss
        # with torch.no_grad():
        #     valid_pl_loader = pl.ParallelLoader(valid_loader, [device]).per_device_loader(device)
        #     model.eval()
        #     for batch_num, batch in enumerate(valid_pl_loader):
        #         scans = batch['scans']
        #         fnc = batch['fnc']
        #         sbm = batch['sbm']
        #         targets = batch['targets']
        #         output = model(fnc, sbm, scans)
        #         del scans
        #         validation_loss = criterion(output, targets)
        #         del targets
        #         xm.master_print(f'validation: index: {batch_num} loss: {validation_loss.item()}')
        #         count = count + 1
        #         average = average + validation_loss.item()
        #         val_loss.append(validation_loss)
        #     del valid_pl_loader
    elapsed_train_time = time.time() - train_start
    print("Process", index, "finished training. Train time was:",
          elapsed_train_time)
    torch.save(
        f'epoch: {epoch}, state_dict: {model.state_dict()}, validation loss: {val_loss}, optimizer: {optimizer.state_dict()}',
        f'{config.hyper_params["model_save_path"]}/validation_loss_{time.time()}.txt'
    )
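multi_core(index, flags) above expects a flags dict; a hedged launch sketch follows, where the key names mirror those referenced in the function and the values (and core count) are placeholders, not the original configuration.
import torch_xla.distributed.xla_multiprocessing as xmp

flags = {
    'seed': 1234,        # placeholder values; only the keys come from multi_core
    'batch_size': 4,
    'num_workers': 2,
    'num_epochs': 1,
}
xmp.spawn(multi_core, args=(flags,), nprocs=8, start_method='fork')  # assumes an 8-core TPU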
Example #10
def run(config):
    def len_parallelloader(self):
        return len(self._loader._loader)
    pl.PerDeviceLoader.__len__ = len_parallelloader

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        xm.master_print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different
    # files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    xm.master_print('Experiment name is %s' % experiment_name)

    device = xm.xla_device(devkind='TPU')

    # Next, build the model
    G = model.Generator(**config)
    D = model.Discriminator(**config)

    # If using EMA, prepare it
    if config['ema']:
        xm.master_print(
            'Preparing EMA for G with decay of {}'.format(
                config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True})
    else:
        xm.master_print('Not using ema...')
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        xm.master_print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        xm.master_print('Casting D to fp16...')
        D = D.half()

    # Prepare state dict, which holds things like itr #
    state_dict = {'itr': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        xm.master_print('Loading weights...')
        utils.load_weights(
            G,
            D,
            state_dict,
            config['weights_root'],
            experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # move everything to TPU
    G = G.to(device)
    D = D.to(device)

    G.optim = optim.Adam(params=G.parameters(), lr=G.lr,
                         betas=(G.B1, G.B2), weight_decay=0,
                         eps=G.adam_eps)
    D.optim = optim.Adam(params=D.parameters(), lr=D.lr,
                         betas=(D.B1, D.B2), weight_decay=0,
                         eps=D.adam_eps)

    # for key, val in G.optim.state.items():
    #  G.optim.state[key]['exp_avg'] = G.optim.state[key]['exp_avg'].to(device)
    #  G.optim.state[key]['exp_avg_sq'] = G.optim.state[key]['exp_avg_sq'].to(device)

    # for key, val in D.optim.state.items():
    #  D.optim.state[key]['exp_avg'] = D.optim.state[key]['exp_avg'].to(device)
    #  D.optim.state[key]['exp_avg_sq'] = D.optim.state[key]['exp_avg_sq'].to(device)

    if config['ema']:
        G_ema = G_ema.to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

    # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    xm.master_print(G)
    xm.master_print(D)
    xm.master_print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    xm.master_print(
        'Test Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    xm.master_print(
        'Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    if xm.is_master_ordinal():
        # Write metadata
        utils.write_metadata(
            config['logs_root'],
            experiment_name,
            config,
            state_dict)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    xm.master_print('Preparing data...')
    loader = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                       'start_itr': state_dict['itr']})

    # Prepare inception metrics: FID and IS
    xm.master_print('Preparing metrics...')

    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'],
        no_inception=config['no_inception'],
        no_fid=config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    def sample(): return utils.prepare_z_y(G_batch_size, G.dim_z,
                                           config['n_classes'], device=device,
                                           fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout
    # training
    fixed_z, fixed_y = sample()

    train = train_fns.GAN_training_function(G, D, GD, sample, ema, state_dict,
                                            config)

    xm.master_print('Beginning training...')

    if xm.is_master_ordinal():
        pbar = tqdm(total=config['total_steps'])
        pbar.n = state_dict['itr']
        pbar.refresh()

    xm.rendezvous('training_starts')
    while (state_dict['itr'] < config['total_steps']):
        pl_loader = pl.ParallelLoader(
            loader, [device]).per_device_loader(device)

        for i, (x, y) in enumerate(pl_loader):
            if xm.is_master_ordinal():
                # Increment the iteration counter
                pbar.update(1)

            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter
            # much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()

            xm.rendezvous('data_collection')
            metrics = train(x, y)

            # train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if ((config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval']))) :
                if xm.is_master_ordinal():
                    train_log.log(itr=int(state_dict['itr']),
                        **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
                xm.rendezvous('Log SVs.')

            # Save weights and copies as configured at specified interval
            if (not (state_dict['itr'] % config['save_every'])):
                if config['G_eval_mode']:
                    xm.master_print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(
                    G,
                    D,
                    G_ema,
                    sample,
                    fixed_z,
                    fixed_y,
                    state_dict,
                    config,
                    experiment_name)

            # Test every specified interval
            if (not (state_dict['itr'] % config['test_every'])):

                which_G = G_ema if config['ema'] and config['use_ema'] else G
                if config['G_eval_mode']:
                    xm.master_print('Switching G to eval mode...')
                    which_G.eval()

                def G_sample():
                    z, y = sample()
                    return which_G(z, which_G.shared(y))

                train_fns.test(
                    G,
                    D,
                    G_ema,
                    sample,
                    state_dict,
                    config,
                    G_sample,
                    get_inception_metrics,
                    experiment_name,
                    test_log)
            
            # Debug : Message print
            # if True:
            #     xm.master_print(met.metrics_report())

            if state_dict['itr'] >= config['total_steps']:
                break
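# Note: the loop above repeatedly uses the same pattern: do the logging/IO work on the
# master ordinal only, then block every replica at the same xm.rendezvous tag so the
# processes stay in step. A minimal sketch of that pattern as a reusable helper (the
# helper name is illustrative, not part of the original code):
import torch_xla.core.xla_model as xm

def master_only(tag, fn, *args, **kwargs):
    # Run fn on the master ordinal only, then synchronise all replicas
    # at the same rendezvous tag so no process runs ahead.
    if xm.is_master_ordinal():
        fn(*args, **kwargs)
    xm.rendezvous(tag)

# e.g. master_only('log_svs', train_log.log, itr=int(state_dict['itr']))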
Example #11
0
def _mp_fn(index, args):
    torch.set_default_tensor_type('torch.FloatTensor')
    distributed_utils.suppress_output(xm.is_master_ordinal())
    main_tpu(args)
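# Functions shaped like _mp_fn(index, args) are meant to be handed to
# torch_xla.distributed.xla_multiprocessing.spawn, which starts one process per TPU core
# and passes the process index as the first argument. A minimal launch sketch
# (parse_args is a placeholder for however the script actually builds args):
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    args = parse_args()  # placeholder: build the args namespace as the script normally would
    # spawn one process per core; each process calls _mp_fn(index, args)
    xmp.spawn(_mp_fn, args=(args,), nprocs=8, start_method='fork')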
Example #12
0
def train_model(model,
                criterion,
                optimizer,
                scheduler,
                i,
                class_names,
                metric_targets,
                metric_types,
                dataset_types,
                data_loaders,
                dataset_sizes,
                device,
                cfg,
                num_epochs=25,
                batch_size=4,
                patience=5,
                lambda_u=1.0,
                threshold=0.95,
                purpose='baseline',
                is_early=True):
    '''Train the model.

    Args:
        model (obj): the model which will be trained
        criterion (obj): the loss function (e.g. cross entropy)
        optimizer (obj): the optimizer (e.g. Adam)
        scheduler (obj): the learning scheduler (e.g. Step decay)
        i (int): the number indicating which model it is
        class_names (dict): class names for images (e.g. {0: 'covid-19', 1: 'pneumonia', 2: 'normal'})
        metric_targets (list): metric targets to calculate performance metrics of the model
                               (e.g. ['all', 'covid-19', 'pneumonia', 'normal'])
        metric_types (list): the performance metrics of the model (e.g. Accuracy, F1-Score and so on)
        dataset_types (list): dataset types for train and test (e.g. ['train', 'test'] or ['train', 'val', 'test'])
        data_loaders (dict): data loaders with the transformations, batch size and so on applied
        dataset_sizes (dict): sizes of train and test datasets
        device (obj): the device where the model will be trained (e.g. cpu, gpu or tpu)
        cfg (dict): the configuration (e.g. TPU usage, sharpening and focal loss settings)
        num_epochs (int): the number of epochs
        batch_size (int): the batch size
        patience (int): the number of epochs to wait before early stopping
        lambda_u (float): the weight applied to the unlabeled loss
        threshold (float): the threshold for predicted results for unlabeled data
        purpose (str): the purpose of the model
        is_early (bool): whether to apply early stopping

    Returns:
        model (obj): the model which was trained
        best_metrics (dict): the results of the best performance metrics after training the model
    '''
    # Import XLA libraries for using TPUs
    if cfg['use_tpu']:
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as pl

    since = time.time()
    if is_early:
        early_stopping = EarlyStopping(patience=patience, verbose=True)
    best_metrics = {m_type: defaultdict(float) for m_type in metric_types}
    epoch_metrics_list = []

    print(f'{"-" * 20}\nModel {i + 1}\n{"-" * 20}\n')
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and test phase
        for phase in dataset_types:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                if cfg['use_tpu'] and not xm.is_master_ordinal():
                    continue
                model.eval()  # Set model to evaluate mode

            epoch_loss = 0.0
            batch_metrics = {
                'tp': defaultdict(int),
                'size': defaultdict(int),
                'fp': defaultdict(int),
                'fn': defaultdict(int)
            }
            mask_ratio = []  # just for fixmatch

            # Create a parallel loader
            if cfg['use_tpu'] and phase == 'train':
                # data_loaders[phase].sampler.set_epoch(epoch)

                final_data_loader = pl.ParallelLoader(
                    data_loaders[phase], [device]).per_device_loader(device)
            else:
                final_data_loader = data_loaders[phase]

            # Iterate over data.
            for batch in final_data_loader:
                size = batch['img_lb'].size(0)
                # Load batches
                if purpose != 'baseline' and phase == 'train':
                    inputs = torch.cat([
                        batch['img_lb'], batch['img_ulb'], batch['img_ulb_wa']
                    ], 0).to(device)
                else:
                    inputs = batch['img_lb'].to(device)
                labels = batch['label'].to(device)
                del batch
                # zero the parameter gradients
                optimizer.zero_grad()

                # Forward the model
                with torch.set_grad_enabled(phase == 'train'):
                    # Calculate labeled loss
                    outputs = model(inputs)
                    if purpose != 'baseline' and phase == 'train':
                        outputs_lb = outputs[:size]
                        outputs_ulb, outputs_ulb_wa = outputs[size:].chunk(2)
                        del outputs
                    else:
                        outputs_lb = outputs

                    _, preds = torch.max(outputs_lb, 1)
                    loss = loss_lb = criterion(outputs_lb, labels)

                    # Calculate unlabeled loss for FixMatch
                    if purpose != 'baseline' and phase == 'train':
                        probs_ulb = torch.softmax(outputs_ulb, dim=-1)
                        probs_ulb, preds_ulb = torch.max(probs_ulb, 1)
                        mask = probs_ulb.ge(threshold).float()
                        if cfg['sharpening']:  # using sharpening
                            # https://github.com/LeeDoYup/FixMatch-pytorch/blob/0e0b492f1cb110a43c765c55105b5f94c13f45fd/models/fixmatch/fixmatch_utils.py#L35
                            # sharpen_output = torch.softmax(outputs_ulb/cfg['temperature'], dim=-1)
                            # log_pred = F.log_softmax(outputs_ulb_wa, dim=-1)
                            # loss_sharpen = (torch.sum(-sharpen_output*log_pred, dim=1) * mask).mean()
                            if cfg['focal_loss']:
                                sharpen_probs_ulb = torch.softmax(
                                    outputs_ulb / cfg['temperature'], dim=-1)
                                log_pred = F.log_softmax(outputs_ulb_wa,
                                                         dim=-1)
                                loss_ulb = torch.sum(-sharpen_probs_ulb *
                                                     log_pred,
                                                     dim=1)
                                pt = torch.exp(-loss_ulb)
                                loss_ulb = (((1 - pt)**cfg['gamma'] *
                                             loss_ulb)).mean()
                            else:
                                sharpen_label = torch.softmax(
                                    outputs_ulb / cfg['temperature'], dim=-1)
                                log_pred = F.log_softmax(outputs_ulb_wa,
                                                         dim=-1)
                                # use the sharpened distribution computed above
                                loss_ulb = torch.sum(-sharpen_label *
                                                     log_pred,
                                                     dim=1).mean()

                            loss += loss_ulb * lambda_u
                        else:  # pseudo label
                            if cfg['focal_loss']:  # Focal loss
                                loss_ulb = F.cross_entropy(outputs_ulb_wa,
                                                           preds_ulb,
                                                           reduction='none')
                                pt = torch.exp(-loss_ulb)
                                loss_ulb = ((
                                    (1 - pt)**cfg['gamma'] * loss_ulb) *
                                            mask).mean()
                            else:  # Previous loss
                                loss_ulb = (F.cross_entropy(outputs_ulb_wa,
                                                            preds_ulb,
                                                            reduction='none') *
                                            mask).mean()

                            mask_ratio.append(mask.mean().item())
                            loss += loss_ulb * lambda_u

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()

                        if cfg['use_tpu']:
                            xm.optimizer_step(optimizer)
                        else:
                            optimizer.step()

                # Calculate loss and metrics per the batch
                if purpose == 'baseline' or phase == 'test':
                    epoch_loss += loss.item() * size
                else:  # FixMatch
                    epoch_loss += (loss_lb.item() * size
                                   + loss_ulb.item() * lambda_u * size * cfg['mu'])

                if not cfg['use_tpu'] or (cfg['use_tpu'] and phase != 'train'):
                    batch_metrics = update_batch_metrics(
                        batch_metrics, preds, labels, class_names)

            if phase == 'train' and scheduler:
                scheduler.step()

            # Calculate the metrics (e.g. Accuracy) per the epoch
            if not cfg['use_tpu'] or (cfg['use_tpu'] and phase != 'train'):
                epoch_metrics = get_epoch_metrics(epoch_loss, dataset_sizes,
                                                  phase, class_names,
                                                  batch_metrics, metric_types)
                print_metrics(epoch_metrics,
                              metric_targets,
                              cfg,
                              phase=phase,
                              mask_ratio=mask_ratio)

            # Add prediction results per the epoch
            if phase != 'train':
                epoch_metrics_list.append(epoch_metrics)

        # Check early stopping
        if phase == 'test' and is_early:
            early_stopping(epoch_metrics['loss']['all'], model)
            if early_stopping.early_stop:
                print("Early stopping!!")
                break

    if not cfg['use_tpu'] or (cfg['use_tpu'] and phase != 'train'
                              and xm.is_master_ordinal()):
        # Extract best case index
        best_acc = (-1, -1)  # (idx, acc)
        for e_met_idx, e_met in enumerate(epoch_metrics_list):
            if e_met['acc']['all'] > best_acc[1]:
                best_acc = (e_met_idx, e_met['acc']['all'])
        best_acc_idx = best_acc[0]

        # Set best metrics from the epoch with the best overall accuracy
        for metric_type in metric_types:  # e.g. ['acc', 'ppv', ...]
            for metric_target in metric_targets:  # e.g. ['all', 'covid-19', ...]
                # Accuracy can't be calculated per class
                if metric_type == 'acc' and metric_target in class_names:
                    continue

                best_metrics[metric_type][metric_target] = \
                    epoch_metrics_list[best_acc_idx][metric_type][metric_target]

        print_metrics(best_metrics, metric_targets, cfg, phase='Best results')

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('-' * 20, '\n')

    return model, best_metrics
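# The example above calls an EarlyStopping helper that is not shown. A minimal sketch
# consistent with the call sites (early_stopping(val_loss, model) and the early_stop
# flag); the implementation is an assumption, not the original class:
class EarlyStopping:
    # Stop training once the monitored loss has not improved for `patience` epochs.
    def __init__(self, patience=5, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        # `model` mirrors the call site; a fuller version would checkpoint it here.
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter}/{self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True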
Example #13
0
def train_mnist(flags,
                training_started=None,
                dynamic_graph=False,
                fetch_often=False):
    torch.manual_seed(1)

    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=600000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=100000 // flags.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(os.path.join(flags.datadir,
                                                    str(xm.get_ordinal())),
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        test_dataset = datasets.MNIST(os.path.join(flags.datadir,
                                                   str(xm.get_ordinal())),
                                      train=False,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    server = xp.start_server(flags.profiler_port)

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                index = max(-step, -flags.batch_size + 1)  # non-empty
                data, target = data[:-index, :, :, :], target[:-index]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()

            with xp.StepTrace('train_mnist', step_num=step):
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss_i = loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              writer))

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            with xp.StepTrace('test_mnist'):
                output = model(data)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum()
                total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
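# _train_update is registered through xm.add_step_closure but not defined in this
# snippet; step closures run after the step's XLA graph has executed, so it is safe to
# call loss.item() inside them. A plausible sketch (an assumption, not the original helper):
def _train_update(device, step, loss, tracker, writer):
    # Executed as a step closure, i.e. after `loss` has been materialised on the host.
    print('device={} step={} loss={:.4f} rate={:.2f} global_rate={:.2f}'.format(
        device, step, loss.item(), tracker.rate(), tracker.global_rate()))
    if writer is not None:
        writer.add_scalar('train/loss', loss.item(), step)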
Example #14
0
def train_imagenet():
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size(),
        )
        if FLAGS.validate:
            test_loader = xu.SampleGenerator(
                data=(
                    torch.zeros(FLAGS.test_set_batch_size, 3, img_dim,
                                img_dim),
                    torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
                ),
                sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
            )
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        if FLAGS.validate:
            test_dataset = torchvision.datasets.ImageFolder(
                os.path.join(FLAGS.datadir, "val"),
                # Matches Torchvision's eval transforms except Torchvision uses size
                # 256 resize for all models both here and in the train loader. Their
                # version crashes during training on 299x299 images, e.g. inception.
                transforms.Compose([
                    transforms.Resize(resize_dim),
                    transforms.CenterCrop(img_dim),
                    transforms.ToTensor(),
                    normalize,
                ]),
            )

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            if FLAGS.validate:
                test_sampler = torch.utils.data.distributed.DistributedSampler(
                    test_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        if FLAGS.validate:
            test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=FLAGS.test_set_batch_size,
                sampler=test_sampler,
                drop_last=FLAGS.drop_last,
                shuffle=False,
                num_workers=FLAGS.num_workers,
            )

    device = xm.xla_device()
    model = get_model_property("model_fn")().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
        scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, "lr_scheduler_divide_every_n_epochs", None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer,
    )
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()

    def train_loop_fn(loader, epoch):
        if FLAGS.fine_grained_metrics:
            epoch_start_time = time.time()
            step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
        else:
            tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if FLAGS.fine_grained_metrics:
                step_start_time = time.time()
            optimizer.zero_grad()
            if FLAGS.fine_grained_metrics:
                fwd_start_time = time.time()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)
            if FLAGS.fine_grained_metrics:
                fwd_end_time = time.time()
                fwd_latency = fwd_end_time - fwd_start_time

                bwd_start_time = time.time()
            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            if lr_scheduler:
                lr_scheduler.step()
            if FLAGS.fine_grained_metrics:
                bwd_end_time = time.time()
                bwd_latency = bwd_end_time - bwd_start_time

                step_latency = bwd_end_time - step_start_time
                step_latency_tracker.append(step_latency)
                bwd_latency_tracker.append(bwd_latency)
                fwd_latency_tracker.append(fwd_latency)
            else:
                tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                if FLAGS.fine_grained_metrics:
                    print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(\
                                                epoch, step, FLAGS.batch_size/p50(step_latency_tracker), FLAGS.batch_size, p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
                else:
                    # _train_update(device, step, loss, tracker, epoch, writer)
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              epoch, writer))
        if FLAGS.fine_grained_metrics:
            epoch_end_time = time.time()
            epoch_latency = epoch_end_time - epoch_start_time
            print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(\
                                            epoch, epoch_latency, FLAGS.batch_size/p50(step_latency_tracker), FLAGS.batch_size, p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                test_utils.print_test_update(device, None, epoch, step)
                # xm.add_step_closure(test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    if FLAGS.validate:
        test_device_loader = pl.MpDeviceLoader(test_loader, device)
        accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print("Epoch {} train begin {}".format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print("Epoch {} train end {}".format(epoch,
                                                       test_utils.now()))
        if FLAGS.validate:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print("Epoch {} test end {}, Accuracy={:.2f}".format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={"Accuracy/test": accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    if FLAGS.validate:
        xm.master_print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy if FLAGS.validate else None
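# The fine-grained metrics branch above relies on a p50 helper that is not shown; it
# presumably returns the median (50th percentile) of the recorded latencies. A short
# sketch under that assumption:
import statistics

def p50(values):
    # Median of the collected per-step latencies.
    return statistics.median(values)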
Example #15
0
def main(index):
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_data_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--reload_data_file",
                        default=None,
                        type=int,
                        help="Reload dataset every X epoch")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help=
        "An optional input evaluation data file to evaluate the perplexity on (a text file)."
    )

    parser.add_argument("--model_type",
                        default="bert",
                        type=str,
                        help="The model architecture to be fine-tuned.")
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint for weights initialization.")

    parser.add_argument(
        "--mlm",
        action='store_true',
        help=
        "Train with masked-language modeling loss instead of language modeling."
    )
    parser.add_argument(
        "--mlm_probability",
        type=float,
        default=0.15,
        help="Ratio of tokens to mask for masked language modeling loss")

    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )
    parser.add_argument("--tokenizer_class",
                        default="",
                        type=str,
                        help="Optional pretrained tokenizer clas")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)"
    )
    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens)."
    )
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Run evaluation during training at each logging step.")
    parser.add_argument('--eval_steps',
                        type=int,
                        default=100,
                        help="Evaluate every X updates steps.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=4,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=4,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for optimizer.")
    parser.add_argument("--sgd",
                        action='store_true',
                        help="Use SGD instead of Adam.")

    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-6,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=1.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_samples",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_samples.")
    parser.add_argument("--lr_decay",
                        action='store_true',
                        help="Decay LR using get_linear_schedule_with_warmup.")
    parser.add_argument(
        "--lr_cosine",
        action='store_true',
        help="LR using get_cosine_with_hard_restarts_schedule_with_warmup.")

    parser.add_argument(
        "--unfreeze_level",
        default=-1,
        type=int,
        help="If > 0: freeze all layers except few first and last.")

    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        '--save_total_limit',
        type=int,
        default=None,
        help=
        'Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default'
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--first_run', action='store_true', help="Cache init")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit/mixed precision instead of 32-bit")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    args = parser.parse_args()
    args.local_rank = index

    if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
        raise ValueError(
            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling).")
    if args.eval_data_file is None and args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    args.n_gpu = xm.xrt_world_size()
    args.device = xm.xla_device()

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if xm.is_master_ordinal() else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
        args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1))

    # Set seed
    # That is actually very important in case of distributed environment (like TPU). You need same dataset on every node/process.
    # If you have randomness in dataset creation (like I do) you need to set the same seed in every process.
    set_seed(args)

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    if os.path.exists(os.path.join(args.output_dir, WEIGHTS_NAME)):
        args.model_name_or_path = args.output_dir
    else:
        args.first_run = True

    # load the model from the web in a single process, or the cached file will be corrupted.
    lock = FileLock("the.lock") if args.first_run else contextlib.suppress()

    with lock:
        config = config_class.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path)
        if args.tokenizer_class:
            tokenizer_class = globals()[args.tokenizer_class]
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name
            if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case)
        if args.block_size <= 0:
            args.block_size = tokenizer.max_len_single_sentence  # Our input block size will be the max possible for the model
        args.block_size = min(args.block_size,
                              tokenizer.max_len_single_sentence)
        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=bool('.ckpt' in args.model_name_or_path),
            config=config)

    if args.fp16:
        model = model2half(model)

    model = model.to(args.device)
    # see https://github.com/pytorch/xla/issues/1245
    model.tie_weights()

    def req_len(model):
        return len([
            param for item in flatten_model(model)
            for param in item.parameters() if param.requires_grad
        ])

    # freeze all layers but few first and last
    if args.unfreeze_level >= 0:
        b_req_len = req_len(model)
        flat = flatten_model(model)
        flat = [item for item in flat if list(item.parameters())]
        i_start = 3
        i_end = 1
        need_grads = set(flat[:i_start + args.unfreeze_level * 3]) | set(
            flat[-(i_end + args.unfreeze_level * 3):])
        for item in flat:
            requires_grad(item, item in need_grads)
        log_info(
            f"Num of layers before {b_req_len}, after freeze {req_len(model)}")

    log_info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train(args, model, tokenizer)
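# The unfreezing block above depends on flatten_model and requires_grad utilities that
# are not included in the snippet. Rough sketches compatible with the call sites
# (assumptions, shown only so the freezing logic can be read end to end):
import torch.nn as nn

def flatten_model(module: nn.Module):
    # Recursively flatten a module tree into a list of leaf modules.
    children = list(module.children())
    if not children:
        return [module]
    return [leaf for child in children for leaf in flatten_model(child)]

def requires_grad(module: nn.Module, flag: bool):
    # Enable or disable gradients for every parameter of the given module.
    for p in module.parameters():
        p.requires_grad_(flag)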
Example #16
0
def tpu_training_loop(index):
    torch.set_default_tensor_type('torch.FloatTensor')
    # To decrease exploding RAM usage, only load and transfer one model at a time
    lock_file = "tpu.lock"
    fd = open(lock_file, "w")
    fcntl.lockf(fd, fcntl.LOCK_EX)

    model_class = GPT2LMHeadModel

    model = model_class.from_pretrained("gpt2")
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained("gpt2", do_lower_case=False)

    device = xm.xla_device()

    logger_is_me = False
    if xm.is_master_ordinal():
        logger_is_me = True
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter()

    special_tokens = {
        "additional_special_tokens": [
            "<TITLE_START>", "<TITLE_END>", "<INSTR_START>", "<NEXT_INSTR>",
            "<INSTR_END>", "<INGR_START>", "<NEXT_INGR>", "<INGR_END>",
            "<RECIPE_START>", "<RECIPE_END>", "<INPUT_START>", "<INPUT_END>",
            "<NEXT_INPUT>"
        ]
    }

    tokenizer.add_special_tokens(special_tokens)
    model.resize_token_embeddings(len(tokenizer))

    train_dataset = TextDataset(file_path="train")
    test_dataset = TextDataset(file_path="test")

    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=xm.xrt_world_size(),
                                       rank=xm.get_ordinal(),
                                       shuffle=True)
    test_sampler = DistributedSampler(test_dataset,
                                      num_replicas=xm.xrt_world_size(),
                                      rank=xm.get_ordinal(),
                                      shuffle=False)

    #PARAMS!!
    train_batch_size = 4
    test_batch_size = 4

    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=train_batch_size)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=test_sampler,
                                 batch_size=test_batch_size)

    model.train().to(device)

    import gc
    gc.collect()

    fcntl.lockf(fd, fcntl.LOCK_UN)

    gradient_steps = 1
    epochs = 1
    t_total = len(train_dataloader) // gradient_steps

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    # one optimizer and scheduler per TPU core (each spawned process builds its own and reuses it every epoch)
    lr = 5e-5 * xm.xrt_world_size()
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=t_total)

    tracker = xm.RateTracker()

    # PARAMS V2!!!
    gradient_steps = 1
    logging_steps = 100
    validation_steps = 1000

    optimizer.zero_grad()

    def single_epoch(big_step, epoch):
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_dataloader, [device])
        for step, batch in enumerate(para_loader.per_device_loader(device)):
            inputs, labels = (batch, batch)
            model.train()
            outputs = model(inputs, labels=labels)
            loss = outputs[0]

            loss = loss / gradient_steps

            loss.backward()
            tracker.add(1)

            if (step + 1) % gradient_steps == 0:
                xm.optimizer_step(optimizer)
                scheduler.step()
                optimizer.zero_grad()
                big_step += 1

                if logger_is_me and (big_step + 1) % logging_steps == 0:
                    xm.add_step_closure(_train_update,
                                        args=(device, big_step, loss, tracker,
                                              scheduler, writer))

                if (big_step + 1) % validation_steps == 0:
                    perplexity = evaluate(model, test_dataloader, device)
                    if logger_is_me:
                        print("Validation loss: ", perplexity)
                        writer.add_scalar("Validation loss", perplexity,
                                          big_step)
        return big_step

    big_step = 0
    #Always pretend to have one more epoch to do, otherwise model won't get saved
    for i in range(1, 6):
        print("Epoch: " + str(i))
        big_step = single_epoch(big_step, i)
        if logger_is_me:
            output_dir = "gpt2-refined-epoch-" + str(i)
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            save_model(model, output_dir)
            tokenizer.save_pretrained(output_dir)
            print("Model saved")
Example #17
0
def log_info(*args, **kwargs):
    if xm.is_master_ordinal():
        logger.info(*args, **kwargs)
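# A typical use of log_info is to replace direct logger.info calls so each message is
# emitted once rather than once per TPU process; a small usage sketch (the module-level
# logger is assumed):
import logging
import torch_xla.core.xla_model as xm

logger = logging.getLogger(__name__)

# Printed a single time by the master ordinal instead of once per core.
log_info('Starting training on %d TPU cores', xm.xrt_world_size())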
Example #18
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    is_master = xm.is_master_ordinal()
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        sampler=train_sampler,
        num_workers=8,
        drop_last=True)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("= %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)
            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank
                            not in [-1, 0]) if is_master else range(
                                epochs_trained, int(args.num_train_epochs))
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        para_loader = pl.ParallelLoader(train_dataloader, [args.device])
        epoch_iterator = tqdm(
            para_loader.per_device_loader(args.device),
            desc="Iteration",
        ) if is_master else para_loader.per_device_loader(args.device)
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            # batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in [
                    "xlm", "roberta", "distilbert", "camembert"
            ]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(
                        model.config, "lang2id"):
                    inputs.update({
                        "langs":
                        (torch.ones(batch[0].shape, dtype=torch.int64) *
                         args.lang_id).to(args.device)
                    })

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)

            loss = outputs[0]

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:

                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                xm.optimizer_step(optimizer)
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                # # Save model checkpoint
                # if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                #     output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                #     if not os.path.exists(output_dir):
                #         os.makedirs(output_dir)
                #     # Take care of distributed/parallel training
                #     model_to_save = model.module if hasattr(model, "module") else model
                #     model_to_save.save_pretrained(output_dir)
                #     tokenizer.save_pretrained(output_dir)
                #
                #     torch.save(args, os.path.join(output_dir, "training_args.bin"))
                #     logger.info("Saving model checkpoint to %s", output_dir)
                #
                #     torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                #     torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                #     logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
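# set_seed(args) is called before and during training but not defined in the snippet;
# a minimal sketch of the usual implementation (an assumption), which matters on TPU
# because every process must build an identical dataset and model:
import random
import numpy as np
import torch

def set_seed(args):
    # Seed every RNG that influences data order or weight initialisation.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)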
Example #19
0
def valid(args, valid_loader, model, device, tokenizer, epoch, f, max_seq_len):
    total_loss = AverageMeter()
    losses1 = AverageMeter()  # start
    losses2 = AverageMeter()  # end
    accuracies1 = AverageMeter()  # start
    accuracies2 = AverageMeter()  # end

    jaccard_scores = AverageMeter()

    model.eval()

    with torch.no_grad():
        t = tqdm(valid_loader, disable=not xm.is_master_ordinal())
        for step, d in enumerate(t):

            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            token_type_ids = d["token_type_ids"].to(device)
            start_position = d["start_position"].to(device)
            end_position = d["end_position"].to(device)

            sentiment_label = d["sentiment_label"].to(device)

            logits1, logits2 = model(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids,
                                     position_ids=None,
                                     head_mask=None)

            #y_true = (start_position, end_position)
            loss1, loss2 = loss_fn((logits1, logits2),
                                   (start_position, end_position))
            loss = loss1 + loss2

            #max_seq_len = 256
            #loss = Closs.loss_fn(logits1, logits2, start_position, end_position,device, max_seq_len)

            acc1, n_position1 = get_position_accuracy(logits1, start_position)
            acc2, n_position2 = get_position_accuracy(logits2, end_position)

            total_loss.update(loss.item(), n_position1)
            losses1.update(loss1.item(), n_position1)
            losses2.update(loss2.item(), n_position2)
            accuracies1.update(acc1, n_position1)
            accuracies2.update(acc2, n_position2)

            jac_score = calculate_jaccard_score(features_dict=d,
                                                start_logits=logits1,
                                                end_logits=logits2,
                                                tokenizer=tokenizer)

            jaccard_scores.update(jac_score)

            print_loss = xm.mesh_reduce("vloss_reduce", total_loss.avg,
                                        reduce_fn)
            print_jac = xm.mesh_reduce("jac_reduce", jaccard_scores.avg,
                                       reduce_fn)
            print_acc1 = xm.mesh_reduce("vacc1_reduce", accuracies1.avg,
                                        reduce_fn)
            print_acc2 = xm.mesh_reduce("vacc2_reduce", accuracies2.avg,
                                        reduce_fn)

            t.set_description(
                f"Eval E:{epoch+1} - Loss:{print_loss:0.2f} - Jac:{print_jac:0.2f} - acc1:{print_acc1:0.2f} - acc2:{print_acc2:0.2f}"
            )

    #print("Valid Jaccard Score : ", jaccard_scores.avg)
    log_ = f"Epoch : {epoch+1} - valid_loss : {total_loss.avg} - \n\
    valid_loss1 : {losses1.avg} - \valid_loss2 : {losses2.avg} - \n\
    valid_acc1 : {accuracies1.avg} - \valid_acc2 : {accuracies2.avg} "

    f.write(log_ + "\n\n")
    f.flush()

    return jaccard_scores.avg, total_loss.avg
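
The reduce_fn passed to xm.mesh_reduce above is not shown in this snippet; a minimal
sketch, assuming it is simply a mean over the values gathered from all cores:

def reduce_fn(values):
    # `values` is the list of per-core scalars collected by xm.mesh_reduce
    return sum(values) / len(values)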
Example #20
def map_fn(index, args):
    """ for tpu """
    # Setup tpu
    device = xm.xla_device()
    args.device = device

    is_master = xm.is_master_ordinal()
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_master else logging.DEBUG,
    )

    # Set seed
    set_seed(args)

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)
    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                tokenizer,
                                                evaluate=False,
                                                output_examples=False)
        logger.info(" data load finished! ")
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Save the trained model and the tokenizer
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        # Take care of distributed/parallel training
        model_to_save = model.module if hasattr(model, "module") else model
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(
            args.output_dir)  # , force_download=True)
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        model.to(args.device)
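
A per-core entry point like map_fn above is typically launched with torch_xla's
multiprocessing spawner; a minimal sketch, where parse_args is a hypothetical CLI
parser from the surrounding script:

import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == "__main__":
    args = parse_args()  # hypothetical argument parser
    # Fork one process per TPU core; each process is invoked as map_fn(index, args).
    xmp.spawn(map_fn, args=(args,), nprocs=8, start_method="fork")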
Example #21
    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            "optimizer",
            lambda: optim.SGD(
                model.parameters(),
                lr=FLAGS.lr,
                momentum=FLAGS.momentum,
                weight_decay=1e-4,
            ),
        )
        lr_scheduler = context.getattr_or(
            "lr_scheduler",
            lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
                scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, "lr_scheduler_divide_every_n_epochs", None
                ),
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=writer if xm.is_master_ordinal() else None,
            ),
        )
        tracker = xm.RateTracker()
        model.train()
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        losses = 0
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5)
            loss = loss_fn(output, target)
            loss.backward()
            losses += loss.item()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print(
                    "[{}]({}) Loss={:.5f} Top-1 ACC = {:.2f} Rate={:.2f} GlobalRate={:.2f} Time={}".format(
                        str(device),
                        x,
                        loss.item(),
                        (100.0 * correct / total_samples).item(),
                        tracker.rate(),
                        tracker.global_rate(),
                        time.asctime(),
                    )
                )

            if lr_scheduler:
                lr_scheduler.step()
        return (
            losses / (x + 1),
            (100.0 * correct / total_samples).item(),
            (top5_accuracys / (x + 1)).item(),
        )
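
The topk_accuracy helper used in this loop is not part of the snippet; a minimal sketch
of one way it might be written (name, signature, and tensor return are assumptions):

def topk_accuracy(output, target, topk=5):
    # output: (N, C) logits, target: (N,) class indices; returns a 0-dim tensor in percent
    _, pred = output.topk(topk, dim=1)
    correct = pred.eq(target.unsqueeze(1)).any(dim=1).float()
    return 100.0 * correct.mean()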
Example #22
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            persistent_workers=True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            persistent_workers=True,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')()
    # Wrap the model with FSDP
    # You may wrap all, a subset, or none of the sub-modules with inner FSDPs
    # - to implement ZeRO-2, wrap none of the sub-modules
    # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP)
    # - you may wrap sub-modules at different granularity (e.g. at each resnet
    #   stage or each residual block or each conv layer).
    fsdp_wrap = lambda m: FSDP(m.to(device),
                               compute_dtype=getattr(torch, FLAGS.compute_dtype),
                               fp32_reduce_scatter=FLAGS.fp32_reduce_scatter,
                               flatten_parameters=FLAGS.flatten_parameters)
    # Apply gradient checkpointing to sub-modules if specified
    grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else (
        lambda x: x)
    if FLAGS.use_nested_fsdp:
        # Here we apply inner FSDP at the level of child modules for ZeRO-3, which
        # corresponds to different stages in resnet (i.e. Stage 1 to 5).
        for submodule_name, submodule in model.named_children():
            if sum(p.numel() for p in submodule.parameters()) == 0:
                # Skip those submodules without parameters (i.e. no need to shard them)
                continue
            # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
            m_fsdp = fsdp_wrap(grad_ckpt_wrap(getattr(model, submodule_name)))
            setattr(model, submodule_name, m_fsdp)
    # Always wrap the base model with an outer FSDP
    model = fsdp_wrap(model)

    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.WarmupAndExponentialDecayScheduler(
        optimizer,
        num_steps_per_epoch=num_training_steps_per_epoch,
        divide_every_n_epochs=FLAGS.lr_scheduler_divide_every_n_epochs,
        divisor=FLAGS.lr_scheduler_divisor,
        num_warmup_epochs=FLAGS.num_warmup_epochs,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()  # do not reduce gradients on sharded params
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, epoch,
                                          writer))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(test_utils.print_test_update,
                                    args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))
        run_eval = ((not FLAGS.test_only_at_end
                     and epoch % FLAGS.eval_interval == 0)
                    or epoch == FLAGS.num_epochs)
        if run_eval:
            accuracy = test_loop_fn(test_device_loader, epoch)
            xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
                epoch, test_utils.now(), accuracy))
            max_accuracy = max(accuracy, max_accuracy)
            test_utils.write_to_summary(
                writer,
                epoch,
                dict_to_write={'Accuracy/test': accuracy},
                write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
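
The _train_update closure passed to xm.add_step_closure above is defined elsewhere in
the script; a hedged sketch of the conventional helper, assuming torch_xla's test
utilities are available as test_utils:

def _train_update(device, step, loss, tracker, epoch, writer):
    # Runs on the host once the step's XLA graph has actually executed.
    test_utils.print_training_update(device,
                                     step,
                                     loss.item(),
                                     tracker.rate(),
                                     tracker.global_rate(),
                                     epoch,
                                     summary_writer=writer)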
Example #23
def train(rank, args):
    print('enter train @ %s' % (rank), flush=True)
    args.rank = rank
    torch.manual_seed(42)

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size()

    train_dataset = get_dataset(args)

    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates

    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)

    train_sampler = None
    if args.gpus:
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=False)

    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size
        if not hasattr(train_dataset, '__getbatch__') else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)

    eval_loader = None
    if args.eval_dir:

        eval_dataset = get_eval_dataset(args)

        eval_sampler = None
        if args.gpus:
            # the 'nccl' process group was already initialized above
            if args.gpus > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset,
                    num_replicas=args.gpus,
                    rank=rank,
                    shuffle=False)
        else:
            if xm.xrt_world_size() > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=False)
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size
            if not hasattr(eval_dataset, '__getbatch__') else None,
            sampler=eval_sampler,
            pin_memory=True,
            shuffle=False,
            num_workers=args.num_workers)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args)

        model.cuda(rank)

        device = torch.device('cuda:' + str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),

            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay)

        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader

    else:
        assert tpu_enabled
        device = xm.xla_device()

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################

        model = get_model(args)

        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device)
        ##  So we first find the shared parameters first
        ##
        ##########################

        shared_parameters = {
            e[0]: e[1:]
            for e in _catalog_shared_params(model)
        }

        model.to(device)

        do_share_parameters_again(model, shared_parameters, log=rank == 0)

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),

            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay)

        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)

        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

        # tracker = xm.RateTracker()
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            traceback.print_exc()
            model.load_state_dict(states, strict=False)

    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )

    step_i = 0

    err = None
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader,
                                            [device]).per_device_loader(device)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break

                report_step = step_i % args.log_interval == 0

                while True:  # loop only to replay the batch after an apex gradient overflow
                    optimizer.zero_grad()

                    total_loss, log = get_loss(model,
                                               sample,
                                               args=args,
                                               device=device,
                                               gpu=args.gpus,
                                               report=report_step)

                    if args.gpus:
                        default_optimizer_step = optimizer.step

                        with amp.scale_loss(total_loss,
                                            optimizer) as scaled_loss:
                            scaled_loss.backward()

                        # If Amp detects an overflow, it patches optimizer.step.  In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break

                        # If an overflow was detected, "optimizer.step" is the patched call,
                        # which does nothing but restore optimizer.step to default_optimizer_step.
                        optimizer.step()
                        if rank == 0:
                            print(
                                "Overflowed, reducing loss scale and replaying batch.",
                                flush=True)

                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()

                        break

                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss

                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(
                                log, log_formatter))
                    else:
                        xm.add_step_closure(_train_update,
                                            args=(log, log_formatter))

                    if args.report_metrics:
                        xm.master_print(met.metrics_report())

                if rank == 0:
                    pbar.update(1)

        if eval_loader is not None:
            model.eval()
            if not args.gpus:
                batches = pl.ParallelLoader(eval_loader,
                                            [device]).per_device_loader(device)
            with torch.no_grad():
                record = OrderedDict()

                for sample in batches:
                    evaluate(model,
                             sample,
                             args=args,
                             device=device,
                             record=record,
                             gpu=args.gpus,
                             report=report_step)

                post_evaluate(record, args=args)

            import json
            print('', flush=True)
            print(json.dumps(record), flush=True)
            print('', flush=True)

    except Exception as _err:
        err = _err
    finally:
        save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0 and args.gpus:
            torch.save(model.state_dict(), save_fn)
            if err:
                raise err
        else:
            xm.save(model.state_dict(), save_fn)
            if err:
                raise err
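
This train(rank, args) entry point serves both a GPU and a TPU backend; a hedged sketch
of how such a function is commonly launched (the launcher wiring is an assumption, not
shown in the snippet):

import torch.multiprocessing as mp
import torch_xla.distributed.xla_multiprocessing as xmp

if args.gpus:
    # one worker process per GPU, ranks 0 .. args.gpus - 1
    mp.spawn(train, args=(args,), nprocs=args.gpus, join=True)
else:
    # one worker process per TPU core
    xmp.spawn(train, args=(args,), nprocs=8, start_method="fork")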
Example #24
def train(i, num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard):
    torch.manual_seed(seed)
    #torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======
    device = xm.xla_device()

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config) #.cuda()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1
    
    model = model.to(device)
    
    trainset = Mel2Samp(**data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset, rank=xm.get_ordinal(), num_replicas=xm.xrt_world_size()) #if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=True,
                              drop_last=True)

    xla_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    
    # Get shared output_directory ready
    if xm.is_master_ordinal():
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and xm.is_master_ordinal():
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, 'logs'))

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()

            mel, audio = batch
            mel, audio = mel.to(device), audio.to(device)
            mel.requires_grad, audio.requires_grad = True, True
            
            outputs = model((mel, audio))

            loss = criterion(outputs)
            
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
                if xm.is_master_ordinal() :
                    print("{}:\t{:.9f}".format(iteration, loss.item()))

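            # barrier=True forces an XLA step barrier here; it is needed because this loop
            # iterates the plain DataLoader rather than the ParallelLoader built above.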
            xm.optimizer_step(optimizer, barrier=True)
            
            #print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if with_tensorboard and xm.is_master_ordinal() :
                logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch)

            if (iteration % iters_per_checkpoint == 0):
                if xm.is_master_ordinal() :
                    checkpoint_path = "{}/waveglow_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
Example #25
def train_imagenet(state_dict):
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler, test_sampler = None, None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  device = xm.xla_device()
  model = get_model_property('model_fn')()
  model.load_state_dict(state_dict)
  model = model.to(device)
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=1e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
      optimizer,
      scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
      scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
      scheduler_divide_every_n_epochs=getattr(
          FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
      num_steps_per_epoch=num_training_steps_per_epoch,
      summary_writer=writer)
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if lr_scheduler:
        lr_scheduler.step()
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            _train_update, args=(device, step, loss, tracker, epoch, writer))

  def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(loader):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            test_utils.print_test_update, args=(device, None, epoch, step))
    accuracy = 100.0 * correct.item() / total_samples
    # accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
    if not FLAGS.test_only_at_end or epoch == FLAGS.num_epochs:
      accuracy = test_loop_fn(test_device_loader, epoch)
      xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
          epoch, test_utils.now(), accuracy))
      max_accuracy = max(accuracy, max_accuracy)
      test_utils.write_to_summary(
          writer,
          epoch,
          dict_to_write={'Accuracy/test': accuracy},
          write_xla_metrics=True)
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy
Example #26
def train(args, model, tokenizer):
    """ Train the model """
    if xm.is_master_ordinal():
        tb_writer = SummaryWriterP(args.output_dir)

    def summary_write(*args, **kwargs):
        if xm.is_master_ordinal():
            tb_writer.add_scalar(*args, **kwargs)

    args.train_batch_size = args.per_gpu_train_batch_size  #* max(1, args.n_gpu)

    train_dataloader = build_dataloader(args, tokenizer)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    # Scale learning rate to num cores
    #args.learning_rate = args.learning_rate * xm.xrt_world_size()
    if args.sgd:
        optimizer = SGD(optimizer_grouped_parameters, lr=args.learning_rate)
    else:
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    warmup_steps = args.warmup_samples // (args.train_batch_size *
                                           xm.xrt_world_size())
    if args.lr_decay:
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    warmup_steps=warmup_steps,
                                                    t_total=t_total)
    elif args.lr_cosine:
        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            warmup_steps=warmup_steps,
            t_total=t_total,
            cycles=args.num_train_epochs)
    else:
        scheduler = WarmupZeroSchedule(optimizer, warmup_steps=warmup_steps)

    # Train!
    tracker = xm.RateTracker()
    log_info("***** Running training *****")
    log_info("  Num Epochs = %d", args.num_train_epochs)
    log_info("  Instantaneous batch size per GPU = %d",
             args.per_gpu_train_batch_size)
    log_info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (xm.xrt_world_size() if args.local_rank != -1 else 1))
    log_info("  Gradient Accumulation steps = %d",
             args.gradient_accumulation_steps)
    log_info("  Total optimization steps = %d", t_total)

    try:
        with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c:
            global_step = int(c.readline())
    except OSError as e:
        global_step = 0

    moving_loss = MovingLoss(10000 // args.logging_steps)

    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=not xm.is_master_ordinal())
    try:
        for epoch in train_iterator:
            p_train_dataloader = pl.ParallelLoader(train_dataloader,
                                                   [args.device])
            epoch_iterator = tqdm(p_train_dataloader.per_device_loader(
                args.device),
                                  total=len(train_dataloader),
                                  desc="Iteration",
                                  disable=not xm.is_master_ordinal())

            model.train()
            for step, batch in enumerate(epoch_iterator):
                optimizer.zero_grad()
                inputs, labels = mask_tokens(
                    batch, tokenizer, args) if args.mlm else (batch, batch)
                outputs = model(
                    inputs, masked_lm_labels=labels) if args.mlm else model(
                        inputs, labels=labels)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                    xm.optimizer_step(optimizer, barrier=True)
                    scheduler.step()
                    global_step += 1
                    tracker.add(args.train_batch_size)

                    if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        # Note: loss.item() must be called on every process (calling it in only
                        # one process hangs the run), so call it everywhere and log in one.
                        ls = loss.item()
                        moving_loss.add(ls)
                        summary_write('lr',
                                      scheduler.get_last_lr()[0], global_step)
                        epoch_iterator.set_postfix(
                            MovingLoss=f'{moving_loss.loss:.2f}',
                            Perplexity=
                            f'{torch.exp(torch.tensor(moving_loss.loss)):.2f}')

                    if args.save_steps > 0 and global_step % args.save_steps == 0:
                        save_state(args, model, tokenizer, global_step)

                #if step >= 1023: # TPU seems to like consistent epoch lenght
                #    epoch_iterator.close()
                #    break

                if args.max_steps > 0 and step > args.max_steps:
                    epoch_iterator.close()
                    break

            # evaluate once in an epoch
            if args.evaluate_during_training:
                results = evaluate(args, model, tokenizer,
                                   f"checkpoint-{global_step}")
                log_info(f"Eval {results}")
                for key, value in results.items():
                    summary_write("eval_{}".format(key), value, global_step)

            # reload dataset every args.reload_data_file epochs
            if args.reload_data_file and (epoch +
                                          1) % args.reload_data_file == 0:
                train_dataloader = build_dataloader(args, tokenizer)

            # that's very slow on TPU
            #print_sample(model, tokenizer, args.device, args)

    except (KeyboardInterrupt, SystemExit):
        save_state(args, model, tokenizer, global_step)
        raise

    save_state(args, model, tokenizer, global_step)

    return global_step, moving_loss.loss
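
The log_info helper used throughout this example is not shown; a minimal sketch,
assuming it just restricts logging to the master ordinal:

def log_info(msg, *args):
    # only the master TPU core emits the message
    if xm.is_master_ordinal():
        logger.info(msg, *args)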
Example #27
    def is_local_master(self) -> bool:
        if is_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]
Example #28
    def summary_write(*args, **kwargs):
        if xm.is_master_ordinal():
            tb_writer.add_scalar(*args, **kwargs)
Example #29
def train(train_loader, model, optimizer, scheduler, epoch, args, DEVICE):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model = model.train().to(DEVICE)

    loader = pl.ParallelLoader(train_loader,
                               [DEVICE]).per_device_loader(DEVICE)
    # noise2net = Res2Net(epsilon=0.50, hidden_planes=16, batch_size=args.batch_size).train().to(DEVICE)

    end = time.time()
    for i, (images, target) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        bx = images
        by = target

        print("Zero grad")
        optimizer.zero_grad()

        # with torch.no_grad():
        #     if random.random() < 0.5:
        #         batch_size = bx.shape[0]
        #         noise2net.reload_parameters()
        #         noise2net.set_epsilon(random.uniform(args.noisenet_max_eps / 2.0, args.noisenet_max_eps))
        #         bx = bx.reshape((1, batch_size * 3, 224, 224))
        #         bx = noise2net(bx)
        #         bx = bx.reshape((batch_size, 3, 224, 224))

        print("Forward")
        logits = model(bx)

        print("Cross Entropy")
        loss = F.cross_entropy(logits, by)

        # measure accuracy and record loss
        output, target = logits, by
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        print("Backward")
        loss.backward()

        print("Step")
        xm.optimizer_step(optimizer)

        print("Scheduler step")
        scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and xm.is_master_ordinal():
            progress.display(i)
Example #30
    def run():
        """
        Main function to setup the training loop and evaluation loop.
        See comments for detailed explanation.

        Returns:
            None, but it saves the model weights and model performance, based on the get_map_fn arguments

        """

        # xla will assign a device for each forked run of this function
        device = xm.xla_device()

        # determine if this fork is the master fork to avoid logging and print 8 times
        master = xm.is_master_ordinal()

        if master:
            logger.info("running at batch size %i" % batch_size)

        criterion = nn.CrossEntropyLoss()

        criterion.to(device)
        model = WRAPPED_MODEL.to(device)

        # standard data prep
        CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
        CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

        train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ]
        )

        if args.cutout > 0:
            train_transform.transforms.append(Cutout(args.cutout))

        train_data = CifarDataset(transform=train_transform)

        # distributed samples ensure data is sharded to each tpu core
        # if you do not use this, you are only using 1 of the 8 cores
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True,
        )

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch_size//xm.xrt_world_size(),
            sampler=train_sampler,
            drop_last=True,
            num_workers=0,
        )

        valid_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ]
        )

        valid_data = my_cifar10.CIFAR10(
            root=data_root, train=False, download=False, transform=valid_transform
        )

        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False,
        )

        valid_queue = torch.utils.data.DataLoader(
            valid_data,
            sampler=valid_sampler,
            batch_size=batch_size//xm.xrt_world_size(),
            drop_last=True,
            num_workers=0,
        )

        # standard optimizer stuff
        parameters = filter(lambda p: p.requires_grad, model.parameters())

        if args.opt == "sgd":

            optimizer = torch.optim.SGD(
                parameters,
                args.learning_rate,
                momentum=momentum,
                weight_decay=args.weight_decay,
            )
        elif args.opt == "lamb":
            optimizer = Lamb(
                parameters, lr=args.learning_rate, weight_decay=weight_decay
            )
        else:
            raise NameError("Unknown Optimizer %s" % args.opt)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs))

        # training by epoch loop
        for epoch in range(epochs):

            # the model needs a droprate, so just assign it
            model.droprate = drop_path_prob * epoch / epochs

            start = datetime.datetime.now()
            st = start.strftime("%Y-%m-%d %H:%M:%S")

            if master:
                logger.info("starting epoch %i at %s" % (epoch, st))

            # parallel loader necessary to load data in parallel to each core
            para_loader = pl.ParallelLoader(train_queue, [device]).per_device_loader(
                device
            )
            correct, train_loss, total = train(
                para_loader, model, criterion, optimizer, params, device
            )

            train_acc = 100 * correct / total

            # collect the train accuracies from all cores
            train_acc = xm.mesh_reduce("avg acc", train_acc, np.mean)

            end = datetime.datetime.now()
            duration = (end - start).total_seconds()

            if master:
                logger.info("train_acc %f duration %f" % (train_acc, duration))

            scheduler.step()

        # validate using 8 cores and collect results
        valid_acc, valid_obj = infer(valid_queue, model, criterion, device)
        valid_acc = xm.mesh_reduce("val avg acc", valid_acc, np.mean)

        if master:
            logger.info("valid_acc %f" % valid_acc)

        # count flops
        _ = add_flops_counting_methods(model)
        model.eval()
        model.start_flops_count()
        random_data = torch.randn(1, 3, 32, 32)
        model(torch.autograd.Variable(random_data).to(device))
        n_flops = np.round(model.compute_average_flops_cost() / 1e6, 4)
        n_flops = xm.mesh_reduce("flops", n_flops, np.mean)

        if master:
            logger.info("flops %f" % n_flops)

        if master:
            logger.info("saving")

        # save weights and results

        xm.save([valid_acc, n_flops], "results.pt")
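
WRAPPED_MODEL above is assumed to be the usual host-memory model wrapper created once
before the cores are forked; a minimal sketch of that setup (build_model is hypothetical):

import torch_xla.distributed.xla_multiprocessing as xmp

# Keep a single copy of the weights in host memory; each forked core then
# materializes them on its own XLA device via WRAPPED_MODEL.to(device).
WRAPPED_MODEL = xmp.MpModelWrapper(build_model())  # build_model() is hypothetical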