Example #1
def main(args):
    logger = bit_common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info("Going to train on {}".format(device))

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info("Loading model from {}.npz".format(args.model))
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(np.load("{}.npz".format(args.model)))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    step = 0

    # If pretrained weights are specified
    if args.weights_path:
        logger.info("Loading weights from {}".format(args.weights_path))
        checkpoint = torch.load(args.weights_path, map_location="cpu")
        # New task might have different classes; remove the pretrained classifier weights
        del checkpoint['model']['module.head.conv.weight']
        del checkpoint['model']['module.head.conv.bias']
        model.load_state_dict(checkpoint["model"], strict=False)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info("Model will be saved in '{}'".format(savename))
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(
            "Found saved model to resume from at '{}'".format(savename))

        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info("Resumed at step {}".format(step))

    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    # Send model and optimizer state to the GPU(s)
    model = model.to(device)
    optimizer_to(optim, device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()

    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            logits = model(x)
            if mixup > 0.0:
                c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
            else:
                c = cri(logits, y)
            c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads
            (c / args.batch_split).backward()

            logger.info("[step {}]: loss={:.5f} (lr={:.1e})".format(
                step, c_num, lr))
            logger.flush()

            # Update params
            optim.step()
            optim.zero_grad()
            step += 1

            # Sample new mixup ratio for next batch
            mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

            end = time.time()
            if step % 50 == 0:
                torch.save(
                    {
                        "step": step,
                        "model": model.state_dict(),
                        "optim": optim.state_dict(),
                    }, savename)

        # Final eval at end of training.
        run_eval(model, valid_loader, device, chrono, logger, step='end')

    logger.info("Timings:\n{}".format(chrono))
Example #2
def main(args):
    logger = bit_common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Going to train on {device}")

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info(f"Loading model from {args.model}.npz")
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(np.load(f"{args.model}.npz"))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    step = 0

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info(f"Model will be saved in '{savename}'")
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(f"Found saved model to resume from at '{savename}'")

        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info(f"Resumed at step {step}")
    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()
    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            with chrono.measure("fprop"):
                logits = model(x)
                if mixup > 0.0:
                    c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
                else:
                    c = cri(logits, y)
                c_num = float(
                    c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads
            with chrono.measure("grads"):
                (c / args.batch_split).backward()
                accum_steps += 1

            accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
            logger.info(
                f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
            logger.flush()

            # Update params
            if accum_steps == args.batch_split:
                with chrono.measure("update"):
                    optim.step()
                    optim.zero_grad()
                step += 1
                accum_steps = 0
                # Sample new mixup ratio for next batch
                mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

                # Run evaluation and save the model.
                if args.eval_every and step % args.eval_every == 0:
                    run_eval(model, valid_loader, device, chrono, logger, step)
                    if args.save:
                        torch.save(
                            {
                                "step": step,
                                "model": model.state_dict(),
                                "optim": optim.state_dict(),
                            }, savename)

            end = time.time()

        # Final eval at end of training.
        run_eval(model, valid_loader, device, chrono, logger, step='end')

    logger.info(f"Timings:\n{chrono}")
Example #3
def main(args):
  logger = bit_common.setup_logger(args)
  if args.test_run:
    args.batch = 8
    args.batch_split = 1
    args.workers = 1

  logger.info("Args: " + str(args))

  # Fix seed
  # torch.manual_seed(args.seed)
  # torch.backends.cudnn.deterministic = True
  # torch.backends.cudnn.benchmark = False
  # np.random.seed(args.seed)
  # random.seed(args.seed)

  # Speed up
  torch.backends.cudnn.benchmark = True

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  logger.info(f"Going to train on {device}")

  n_train, n_classes, train_loader, valid_loader = mktrainval(args, logger)

  if args.inpaint != 'none':
    if args.inpaint == 'mean':
      inpaint_model = (lambda x, mask: x*mask)
    elif args.inpaint == 'random':
      inpaint_model = RandomColorWithNoiseInpainter((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    elif args.inpaint == 'local':
      inpaint_model = LocalMeanInpainter(window=15)  # NOTE: window size is missing in the source; 15 is only a placeholder
    elif args.inpaint == 'cagan':
      inpaint_model = CAInpainter(
        valid_loader.batch_size, checkpoint_dir='./inpainting_models/release_imagenet_256/')
    else:
      raise NotImplementedError(f"Unknown inpaint {args.inpaint}")


  logger.info(f"Training {args.model}")
  if args.model in models.KNOWN_MODELS:
    model = models.KNOWN_MODELS[args.model](head_size=n_classes, zero_head=True)
  else: # from torchvision
    model = getattr(torchvision.models, args.model)(pretrained=args.finetune)

  # Resume fine-tuning if we find a saved model.
  step = 0

  # Optionally resume from a checkpoint.
  # Load it to CPU first as we'll move the model to GPU later.
  # This way, we save a little bit of GPU memory when loading.
  savename = pjoin(args.logdir, args.name, "model.pt")
  try:
    logger.info(f"Model will be saved in '{savename}'")
    checkpoint = torch.load(savename, map_location="cpu")
    logger.info(f"Found saved model to resume from at '{savename}'")

    step = checkpoint["step"]
    model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
    optim.load_state_dict(checkpoint["optim"])
    logger.info(f"Resumed at step {step}")
  except FileNotFoundError:
    if args.finetune:
      logger.info("Fine-tuning from BiT")
      model.load_from(np.load(f"models/{args.model}.npz"))

    model = model.to(device)
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

  if args.fp16:
    model, optim = amp.initialize(model, optim, opt_level="O1")

  logger.info("Moving model onto all GPUs")
  model = torch.nn.DataParallel(model)

  optim.zero_grad()

  model.train()
  mixup = 0
  if args.mixup:
    mixup = bit_hyperrule.get_mixup(n_train)

  cri = torch.nn.CrossEntropyLoss().to(device)
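  # counterfact_cri below is ordinary cross-entropy for regular labels (y >= 0).
  # For counterfactual labels (y < 0, where -1 encodes "not class 0", -2 "not
  # class 1", ...), it instead minimises -log(1 - softmax(logits)[cf_y]), i.e.
  # it pushes probability mass away from the class whose evidence was inpainted out.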
  def counterfact_cri(logit, y):
    if torch.all(y >= 0):
      return F.cross_entropy(logit, y, reduction='mean')

    loss1 = F.cross_entropy(logit[y >= 0], y[y >= 0], reduction='sum')

    cf_logit, cf_y = logit[y < 0], -(y[y < 0] + 1)

    # Implement my own logsumexp trick
    m, _ = torch.max(cf_logit, dim=1, keepdim=True)
    exp_logit = torch.exp(cf_logit - m)
    sum_exp_logit = torch.sum(exp_logit, dim=1)

    eps = 1e-20
    num = (sum_exp_logit - exp_logit[torch.arange(exp_logit.shape[0]), cf_y])
    num = torch.log(num + eps)
    denom = torch.log(sum_exp_logit + eps)

    # Negative log probability
    loss2 = -(num - denom).sum()
    return (loss1 + loss2) / y.shape[0]

  logger.info("Starting training!")
  chrono = lb.Chrono()
  accum_steps = 0
  mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
  end = time.time()

  with lb.Uninterrupt() as u:
    for x, y in recycle(train_loader):
      # measure data loading time, which is spent in the `for` statement.
      chrono._done("load", time.time() - end)

      if u.interrupted:
        break

      # Handle inpainting
      if not isinstance(x, Sample) or x.bbox == [None] * len(x.bbox):
        criterion = cri
      else:
        criterion = counterfact_cri

        bboxes = x.bbox
        x = x.img

        # is_bbox_exists = x.new_ones(x.shape[0], dtype=torch.bool)
        mask = x.new_ones(x.shape[0], 1, *x.shape[2:])
        for i, bbox in enumerate(bboxes):
          for coord_x, coord_y, w, h in zip(bbox.xs, bbox.ys, bbox.ws, bbox.hs):
            mask[i, 0, coord_y:(coord_y + h), coord_x:(coord_x + w)] = 0.

        impute_x = inpaint_model(x, mask)
        impute_y = (-y - 1)

        x = torch.cat([x, impute_x], dim=0)
        # label -1 as negative of class 0, -2 as negative of class 1 etc...
        y = torch.cat([y, impute_y], dim=0)

      # Schedule sending to GPU(s)
      x = x.to(device, non_blocking=True)
      y = y.to(device, non_blocking=True)

      # Update learning-rate, including stop training if over.
      lr = bit_hyperrule.get_lr(step, n_train, args.base_lr)
      if lr is None:
        break
      for param_group in optim.param_groups:
        param_group["lr"] = lr

      if mixup > 0.0:
        x, y_a, y_b = mixup_data(x, y, mixup_l)

      # compute output
      with chrono.measure("fprop"):
        logits = model(x)
        if mixup > 0.0:
          c = mixup_criterion(criterion, logits, y_a, y_b, mixup_l)
        else:
          c = criterion(logits, y)
        c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

      # Accumulate grads
      with chrono.measure("grads"):
        loss = (c / args.batch_split)
        if args.fp16:
          with amp.scale_loss(loss, optim) as scaled_loss:
            scaled_loss.backward()
        else:
          loss.backward()
        accum_steps += 1

      accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
      logger.info(f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
      logger.flush()

      # Update params
      if accum_steps == args.batch_split:
        with chrono.measure("update"):
          optim.step()
          optim.zero_grad()
        step += 1
        accum_steps = 0
        # Sample new mixup ratio for next batch
        mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

        # Run evaluation and save the model.
        if args.eval_every and step % args.eval_every == 0:
          run_eval(model, valid_loader, device, chrono, logger, step)
          if args.save:
            torch.save({
                "step": step,
                "model": model.module.state_dict(),
                "optim": optim.state_dict(),
            }, savename)

      end = time.time()

    # Save model!!
    if args.save:
      torch.save({
        "step": step,
        "model": model.module.state_dict(),
        "optim": optim.state_dict(),
      }, savename)

    json.dump({
      'model': args.model,
      'head_size': n_classes,
      'inpaint': args.inpaint,
      'dataset': args.dataset,
    }, open(pjoin(args.logdir, args.name, 'hyperparams.json'), 'w'))

    # Final eval at end of training.
    run_eval(model, valid_loader, device, chrono, logger, step='end')


  logger.info(f"Timings:\n{chrono}")