def on_train_batch_begin(self, batch, logs=None):
    lr = bit_hyperrule.get_lr(self.step, self.num_samples, self.base_lr)
    tf.keras.backend.set_value(self.model.optimizer.lr, lr)
    self.step += 1
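# Context sketch (an assumption, not part of the original snippet): the hook
# above would normally live on a Keras callback that tracks the global step and
# applies the BiT hyper-rule learning-rate schedule, roughly like this:

import tensorflow as tf

import bit_hyperrule  # helper module from the BiT repository


class BiTLearningRateSchedule(tf.keras.callbacks.Callback):
    """Hypothetical wrapper around the on_train_batch_begin hook shown above."""

    def __init__(self, num_samples, base_lr):
        super().__init__()
        self.step = 0
        self.num_samples = num_samples
        self.base_lr = base_lr

    def on_train_batch_begin(self, batch, logs=None):
        lr = bit_hyperrule.get_lr(self.step, self.num_samples, self.base_lr)
        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        self.step += 1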
Example #2
def main(args):
    logger = bit_common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Going to train on {device}")

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info(f"Loading model from {args.model}.npz")
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(np.load(f"{args.model}.npz"))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    step = 0

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info(f"Model will be saved in '{savename}'")
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(f"Found saved model to resume from at '{savename}'")

        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info(f"Resumed at step {step}")
    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()
    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            with chrono.measure("fprop"):
                logits = model(x)
                if mixup > 0.0:
                    c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
                else:
                    c = cri(logits, y)
                c_num = float(
                    c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads
            with chrono.measure("grads"):
                (c / args.batch_split).backward()
                accum_steps += 1

            accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
            logger.info(
                f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
            logger.flush()

            # Update params
            if accum_steps == args.batch_split:
                with chrono.measure("update"):
                    optim.step()
                    optim.zero_grad()
                step += 1
                accum_steps = 0
                # Sample new mixup ratio for next batch
                mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

                # Run evaluation and save the model.
                if args.eval_every and step % args.eval_every == 0:
                    run_eval(model, valid_loader, device, chrono, logger, step)
                    if args.save:
                        torch.save(
                            {
                                "step": step,
                                "model": model.state_dict(),
                                "optim": optim.state_dict(),
                            }, savename)

            end = time.time()

        # Final eval at end of training.
        run_eval(model, valid_loader, device, chrono, logger, step='end')

    logger.info(f"Timings:\n{chrono}")
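# The loop above uses a few helpers that are not part of this snippet. The
# sketches below follow the reference implementations in the BiT repository
# (bit_pytorch/train.py); treat them as illustrative rather than authoritative.

import torch


def recycle(iterable):
    """Variant of itertools.cycle that does not cache iterates, so every pass
    over a DataLoader re-shuffles the data."""
    while True:
        for item in iterable:
            yield item


def mixup_data(x, y, l):
    """Mix each input with a randomly chosen partner from the same batch and
    return the mixed inputs plus both sets of targets."""
    indices = torch.randperm(x.shape[0]).to(x.device)
    mixed_x = l * x + (1 - l) * x[indices]
    y_a, y_b = y, y[indices]
    return mixed_x, y_a, y_b


def mixup_criterion(criterion, pred, y_a, y_b, l):
    """Mixup loss: the same convex combination applied to the two losses."""
    return l * criterion(pred, y_a) + (1 - l) * criterion(pred, y_b)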
Example #3
def main(args):
    logger = bit_common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info("Going to train on {}".format(device))

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info("Loading model from {}.npz".format(args.model))
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(np.load("{}.npz".format(args.model)))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    step = 0

    # If pretrained weights are specified
    if args.weights_path:
        logger.info("Loading weights from {}".format(args.weights_path))
        checkpoint = torch.load(args.weights_path, map_location="cpu")
        # New task might have different classes; remove the pretrained classifier weights
        del checkpoint['model']['module.head.conv.weight']
        del checkpoint['model']['module.head.conv.bias']
        model.load_state_dict(checkpoint["model"], strict=False)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info("Model will be saved in '{}'".format(savename))
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(
            "Found saved model to resume from at '{}'".format(savename))

        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info("Resumed at step {}".format(step))

    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    # Send to GPU
    optimizer_to(optim, device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()

    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            logits = model(x)
            if mixup > 0.0:
                c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
            else:
                c = cri(logits, y)
            c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads
            (c / args.batch_split).backward()

            logger.info("[step {}]: loss={:.5f} (lr={:.1e})".format(
                step, c_num, lr))
            logger.flush()

            # Update params
            optim.step()
            optim.zero_grad()
            step += 1

            # Sample new mixup ratio for next batch
            mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

            end = time.time()
            if step % 50 == 0:
                torch.save(
                    {
                        "step": step,
                        "model": model.state_dict(),
                        "optim": optim.state_dict(),
                    }, savename)

        # Final eval at end of training.
        run_eval(model, valid_loader, device, chrono, logger, step='end')

    logger.info("Timings:\n{}".format(chrono))
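# optimizer_to, used above to move the optimizer state to the GPU after loading
# a CPU checkpoint, is not defined in this snippet. A minimal sketch of the
# usual pattern (an assumption, not the original implementation):

import torch


def optimizer_to(optim, device):
    """Move every tensor held in the optimizer state (e.g. SGD momentum
    buffers) onto the given device."""
    for state in optim.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)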
Example #4
def main(args):
    logger = bit_common.setup_logger(args)

    logger.info(f'Available devices: {jax.devices()}')

    model = models.KNOWN_MODELS[args.model]

    # Load weights of a BiT model
    bit_model_file = os.path.join(args.bit_pretrained_dir, f'{args.model}.npz')
    if not os.path.exists(bit_model_file):
        raise FileNotFoundError(
            f'Model file is not found in "{args.bit_pretrained_dir}" directory.'
        )
    with open(bit_model_file, 'rb') as f:
        params_tf = np.load(f)
        params_tf = dict(zip(params_tf.keys(), params_tf.values()))

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(
        args.dataset)

    # Setup input pipeline
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train',
                                                   args.examples_per_class)

    data_train = input_pipeline.get_data(
        dataset=args.dataset,
        mode='train',
        repeats=None,
        batch_size=args.batch,
        resize_size=resize_size,
        crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(dataset_info['num_examples']),
        num_devices=jax.local_device_count(),
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_train)
    data_test = input_pipeline.get_data(dataset=args.dataset,
                                        mode='test',
                                        repeats=1,
                                        batch_size=args.batch_eval,
                                        resize_size=resize_size,
                                        crop_size=crop_size,
                                        examples_per_class=None,
                                        examples_per_class_seed=0,
                                        mixup_alpha=None,
                                        num_devices=jax.local_device_count(),
                                        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_test)

    # Build ResNet architecture
    ResNet = model.partial(num_classes=dataset_info['num_classes'])
    _, params = ResNet.init_by_shape(
        jax.random.PRNGKey(0), [([1, crop_size, crop_size, 3], jnp.float32)])
    resnet_fn = ResNet.call

    # pmap replicates the models over all GPUs
    resnet_fn_repl = jax.pmap(ResNet.call)

    def cross_entropy_loss(*, logits, labels):
        logp = jax.nn.log_softmax(logits)
        return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
        logits = resnet_fn(params, images)
        return cross_entropy_loss(logits=logits, labels=labels)

    # Update step, replicated over all GPUs
    @partial(jax.pmap, axis_name='batch')
    def update_fn(opt, lr, batch):
        l, g = jax.value_and_grad(loss_fn)(opt.target, batch['image'],
                                           batch['label'])
        g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
        opt = opt.apply_gradient(g, learning_rate=lr)
        return opt

    # In-place update of the randomly initialized weights with the BiT weights
    tf2jax.transform_params(params,
                            params_tf,
                            num_classes=dataset_info['num_classes'])

    # Create optimizer and replicate it over all GPUs
    opt = optim.Momentum(beta=0.9).create(params)
    opt_repl = flax_utils.replicate(opt)

    # Delete references to the objects that are not needed anymore
    del opt
    del params

    total_steps = bit_hyperrule.get_schedule(dataset_info['num_examples'])[-1]

    # Run training loop
    for step, batch in zip(range(1, total_steps + 1),
                           data_train.as_numpy_iterator()):
        lr = bit_hyperrule.get_lr(step - 1, dataset_info['num_examples'],
                                  args.base_lr)
        opt_repl = update_fn(opt_repl, flax_utils.replicate(lr), batch)

        # Run eval step
        if ((args.eval_every and step % args.eval_every == 0)
                or (step == total_steps)):

            accuracy_test = np.mean([
                c for batch in data_test.as_numpy_iterator()
                for c in (np.argmax(
                    resnet_fn_repl(opt_repl.target, batch['image']), axis=2) ==
                          np.argmax(batch['label'], axis=2)).ravel()
            ])

            logger.info(f'Step: {step}, '
                        f'learning rate: {lr:.07f}, '
                        f'Test accuracy: {accuracy_test:0.3f}')
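    # (Assumed addition, not part of the original example.) The loop ends after
    # the final eval; to keep a single copy of the fine-tuned parameters, the
    # replicated optimizer can be collapsed back with flax's unreplicate. This
    # assumes `from flax import serialization` among the (unshown) imports; the
    # output file name is a placeholder.
    params_finetuned = flax_utils.unreplicate(opt_repl).target
    with open('bit_finetuned.msgpack', 'wb') as f:
        f.write(serialization.to_bytes(params_finetuned))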
Example #5
def main(args):
  logger = bit_common.setup_logger(args)
  if args.test_run:
    args.batch = 8
    args.batch_split = 1
    args.workers = 1

  logger.info("Args: " + str(args))

  # Fix seed
  # torch.manual_seed(args.seed)
  # torch.backends.cudnn.deterministic = True
  # torch.backends.cudnn.benchmark = False
  # np.random.seed(args.seed)
  # random.seed(args.seed)

  # Speed up
  torch.backends.cudnn.benchmark = True

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  logger.info(f"Going to train on {device}")

  n_train, n_classes, train_loader, valid_loader = mktrainval(args, logger)

  if args.inpaint != 'none':
    if args.inpaint == 'mean':
      inpaint_model = (lambda x, mask: x*mask)
    elif args.inpaint == 'random':
      inpaint_model = RandomColorWithNoiseInpainter((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    elif args.inpaint == 'local':
      inpaint_model = LocalMeanInpainter(window=15)  # window size missing in the original; 15 is an arbitrary placeholder
    elif args.inpaint == 'cagan':
      inpaint_model = CAInpainter(
        valid_loader.batch_size, checkpoint_dir='./inpainting_models/release_imagenet_256/')
    else:
      raise NotImplementedError(f"Unknown inpaint {args.inpaint}")


  logger.info(f"Training {args.model}")
  if args.model in models.KNOWN_MODELS:
    model = models.KNOWN_MODELS[args.model](head_size=n_classes, zero_head=True)
  else: # from torchvision
    model = getattr(torchvision.models, args.model)(pretrained=args.finetune)

  # Resume fine-tuning if we find a saved model.
  step = 0

  # Optionally resume from a checkpoint.
  # Load it to CPU first as we'll move the model to GPU later.
  # This way, we save a little bit of GPU memory when loading.
  savename = pjoin(args.logdir, args.name, "model.pt")
  try:
    logger.info(f"Model will be saved in '{savename}'")
    checkpoint = torch.load(savename, map_location="cpu")
    logger.info(f"Found saved model to resume from at '{savename}'")

    step = checkpoint["step"]
    model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
    optim.load_state_dict(checkpoint["optim"])
    logger.info(f"Resumed at step {step}")
  except FileNotFoundError:
    if args.finetune:
      logger.info("Fine-tuning from BiT")
      model.load_from(np.load(f"models/{args.model}.npz"))

    model = model.to(device)
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

  if args.fp16:
    model, optim = amp.initialize(model, optim, opt_level="O1")

  logger.info("Moving model onto all GPUs")
  model = torch.nn.DataParallel(model)

  optim.zero_grad()

  model.train()
  mixup = 0
  if args.mixup:
    mixup = bit_hyperrule.get_mixup(n_train)

  cri = torch.nn.CrossEntropyLoss().to(device)
  def counterfact_cri(logit, y):
    if torch.all(y >= 0):
      return F.cross_entropy(logit, y, reduction='mean')

    loss1 = F.cross_entropy(logit[y >= 0], y[y >= 0], reduction='sum')

    cf_logit, cf_y = logit[y < 0], -(y[y < 0] + 1)

    # Implement my own logsumexp trick
    m, _ = torch.max(cf_logit, dim=1, keepdim=True)
    exp_logit = torch.exp(cf_logit - m)
    sum_exp_logit = torch.sum(exp_logit, dim=1)

    eps = 1e-20
    num = (sum_exp_logit - exp_logit[torch.arange(exp_logit.shape[0]), cf_y])
    num = torch.log(num + eps)
    denom = torch.log(sum_exp_logit + eps)

    # Negative log probability
    loss2 = -(num - denom).sum()
    return (loss1 + loss2) / y.shape[0]
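  # Illustration (comments only, not executed): with
  #   logits = torch.tensor([[2.0, 0.5, -1.0],
  #                          [0.1, 1.5,  0.3]])
  #   y      = torch.tensor([0, -2])
  # the first sample is an ordinary class-0 target, while y = -2 encodes the
  # counterfactual label "not class 1" (cf_y = -(-2 + 1) = 1), so loss2 is the
  # negative log-probability of the model not predicting class 1.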

  logger.info("Starting training!")
  chrono = lb.Chrono()
  accum_steps = 0
  mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
  end = time.time()

  with lb.Uninterrupt() as u:
    for x, y in recycle(train_loader):
      # measure data loading time, which is spent in the `for` statement.
      chrono._done("load", time.time() - end)

      if u.interrupted:
        break

      # Handle inpainting
      if not isinstance(x, Sample) or x.bbox == [None] * len(x.bbox):
        criterion = cri
      else:
        criterion = counterfact_cri

        bboxes = x.bbox
        x = x.img

        # is_bbox_exists = x.new_ones(x.shape[0], dtype=torch.bool)
        mask = x.new_ones(x.shape[0], 1, *x.shape[2:])
        for i, bbox in enumerate(bboxes):
          for coord_x, coord_y, w, h in zip(bbox.xs, bbox.ys, bbox.ws, bbox.hs):
            mask[i, 0, coord_y:(coord_y + h), coord_x:(coord_x + w)] = 0.

        impute_x = inpaint_model(x, mask)
        impute_y = (-y - 1)

        x = torch.cat([x, impute_x], dim=0)
        # label -1 as negative of class 0, -2 as negative of class 1 etc...
        y = torch.cat([y, impute_y], dim=0)

      # Schedule sending to GPU(s)
      x = x.to(device, non_blocking=True)
      y = y.to(device, non_blocking=True)

      # Update learning-rate, including stop training if over.
      lr = bit_hyperrule.get_lr(step, n_train, args.base_lr)
      if lr is None:
        break
      for param_group in optim.param_groups:
        param_group["lr"] = lr

      if mixup > 0.0:
        x, y_a, y_b = mixup_data(x, y, mixup_l)

      # compute output
      with chrono.measure("fprop"):
        logits = model(x)
        if mixup > 0.0:
          c = mixup_criterion(criterion, logits, y_a, y_b, mixup_l)
        else:
          c = criterion(logits, y)
        c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

      # Accumulate grads
      with chrono.measure("grads"):
        loss = (c / args.batch_split)
        if args.fp16:
          with amp.scale_loss(loss, optim) as scaled_loss:
            scaled_loss.backward()
        else:
          loss.backward()
        accum_steps += 1

      accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
      logger.info(f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
      logger.flush()

      # Update params
      if accum_steps == args.batch_split:
        with chrono.measure("update"):
          optim.step()
          optim.zero_grad()
        step += 1
        accum_steps = 0
        # Sample new mixup ratio for next batch
        mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

        # Run evaluation and save the model.
        if args.eval_every and step % args.eval_every == 0:
          run_eval(model, valid_loader, device, chrono, logger, step)
          if args.save:
            torch.save({
                "step": step,
                "model": model.module.state_dict(),
                "optim": optim.state_dict(),
            }, savename)

      end = time.time()

    # Save model!!
    if args.save:
      torch.save({
        "step": step,
        "model": model.module.state_dict(),
        "optim": optim.state_dict(),
      }, savename)

    with open(pjoin(args.logdir, args.name, 'hyperparams.json'), 'w') as f:
      json.dump({
        'model': args.model,
        'head_size': n_classes,
        'inpaint': args.inpaint,
        'dataset': args.dataset,
      }, f)

    # Final eval at end of training.
    run_eval(model, valid_loader, device, chrono, logger, step='end')


  logger.info(f"Timings:\n{chrono}")
Example #6
def main():

    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    train_set, valid_set, train_loader, valid_loader = mktrainval()

    model = models.KNOWN_MODELS["BiT-M-R50x1"](
        head_size=len(valid_set.classes), zero_head=True)
    model.load_from(np.load("BiT-M-R50x1.npz"))

    model = torch.nn.DataParallel(model)
    step = 0

    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    savename = pjoin("./log", "cifar10", "bit.pth.tar")
    try:
        checkpoint = torch.load(savename, map_location="cpu")
        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])

    except FileNotFoundError:
        print('model not found; starting from the BiT weights')
    model = model.to(device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    all_top1 = []
    all_loss = []
    all_val_loss = []
    for x, y in recycle(train_loader):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        lr = bit_hyperrule.get_lr(step, len(train_set), 0.003)
        if lr is None:
            break
        for param_group in optim.param_groups:
            param_group["lr"] = lr

        if mixup > 0.0:
            x, y_a, y_b = mixup_data(x, y, mixup_l)

        logits = model(x)
        if mixup > 0.0:
            c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
        else:
            c = cri(logits, y)
        c_num = float(c.data.cpu().numpy())

        c.backward()
        top1, _, _1 = topk(logits, y, ks=(1, 5))
        all_top1.extend(top1)
        print(
            f"[step {step}]: loss={c_num:.5f}, accu={np.mean(all_top1):.2%} (lr={lr:.1e})"
        )
        all_loss.append(c_num)
        all_top1 = []
        optim.step()
        optim.zero_grad()
        step += 1

        mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

        val_loss = run_eval(model, valid_loader, device, step)
        all_val_loss.append(val_loss)

        if step % 10 == 0:
            torch.save(
                {
                    "step": step,
                    "model": model.state_dict(),
                    "optim": optim.state_dict(),
                }, savename)
            print("model saved to " + savename)
    # Plot the recorded training/validation losses against the step index.
    plt.figure(figsize=(8, 8))
    plt.plot(range(1, len(all_loss) + 1), all_loss, label='Training loss')
    plt.plot(range(1, len(all_val_loss) + 1), all_val_loss, label='Validation loss')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.legend(loc='lower right')
    plt.title('Training and validation loss per step')
    plt.show()
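# topk and run_eval above are custom helpers not defined in this snippet. For
# reference, a sketch of topk as implemented in the BiT repository; note that
# it returns one boolean vector per k (two values for ks=(1, 5)), so the
# three-way unpacking above suggests this example uses a modified variant.

import torch


def topk(output, target, ks=(1,)):
    """For each k in ks, return a boolean vector indicating whether the target
    is among the top-k predictions."""
    _, pred = output.topk(max(ks), 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].max(0)[0] for k in ks]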