コード例 #1
0
def main(args):
    """Evaluate a saved checkpoint on the validation set.

    Builds the architecture selected by ``args.model``, restores the weights
    stored in ``models/rotation_augmentated.pth.tar`` and passes the model to
    ``run_eval``.

    Args:
        args: Parsed command-line namespace. This function reads ``save`` and
            ``model`` and forwards the whole namespace to
            ``bit_common.setup_logger`` and ``mkval``.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    print(f"Args save is {args.save}")
    logger = bit_common.setup_logger(args)
    torch.backends.cudnn.benchmark = True

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for evaluation but is not available")
    device = torch.device("cuda:0")

    chrono = lb.Chrono()
    logger.info("Validating")
    valid_set, valid_loader, classes = mkval(args)

    # The weights come from the checkpoint below, so that is the path we log.
    savename = pjoin('models', "rotation_augmentated.pth.tar")
    logger.info(f"Loading model from {savename}")
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    # Wrapped before loading; presumably the checkpoint was saved from a
    # DataParallel model and carries "module."-prefixed keys — TODO confirm.
    model = torch.nn.DataParallel(model)

    # Deserialize on CPU first, then move the populated model to the GPU.
    checkpoint = torch.load(savename, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    run_eval(model, valid_loader, device, chrono, logger, classes)
コード例 #2
0
def main(args):
    """Fine-tune a pretrained BiT ResNet-V2 with Keras and log test accuracy.

    Downloads the pretrained weights if they are not cached locally, builds
    the train/test input pipelines, swaps the classification head for one
    sized to the target dataset, trains with ``model.fit`` and finally logs
    the validation accuracy recorded after every Keras epoch.

    Args:
        args: Parsed command-line namespace. Fields read here: ``logdir``,
            ``bit_pretrained_dir``, ``model``, ``dataset``,
            ``examples_per_class``, ``examples_per_class_seed``, ``batch``,
            ``tfds_manual_dir``, ``eval_every`` and ``base_lr``.
    """
    tf.io.gfile.makedirs(args.logdir)
    logger = bit_common.setup_logger(args)

    logger.info(f'Available devices: {tf.config.list_physical_devices()}')

    # Fetch the pretrained weights once; later runs reuse the local copy.
    tf.io.gfile.makedirs(args.bit_pretrained_dir)
    bit_model_file = os.path.join(args.bit_pretrained_dir, f'{args.model}.h5')
    if not tf.io.gfile.exists(bit_model_file):
        model_url = models.KNOWN_MODELS[args.model]
        logger.info(f'Downloading the model from {model_url}...')
        tf.io.gfile.copy(model_url, bit_model_file)

    # Set up input pipeline
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train',
                                                   args.examples_per_class)

    # Distribute training
    strategy = tf.distribute.MirroredStrategy()
    num_devices = strategy.num_replicas_in_sync
    print('Number of devices: {}'.format(num_devices))

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(
        args.dataset)
    data_train = input_pipeline.get_data(
        dataset=args.dataset,
        mode='train',
        repeats=None,
        batch_size=args.batch,
        resize_size=resize_size,
        crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(dataset_info['num_examples']),
        num_devices=num_devices,
        tfds_manual_dir=args.tfds_manual_dir)
    data_test = input_pipeline.get_data(dataset=args.dataset,
                                        mode='test',
                                        repeats=1,
                                        batch_size=args.batch,
                                        resize_size=resize_size,
                                        crop_size=crop_size,
                                        examples_per_class=1,
                                        examples_per_class_seed=0,
                                        mixup_alpha=None,
                                        num_devices=num_devices,
                                        tfds_manual_dir=args.tfds_manual_dir)

    data_train = data_train.map(lambda x: reshape_for_keras(
        x, batch_size=args.batch, crop_size=crop_size))
    data_test = data_test.map(lambda x: reshape_for_keras(
        x, batch_size=args.batch, crop_size=crop_size))

    with strategy.scope():
        # The filter multiplier is derived from the trailing digit of the
        # model name.
        filters_factor = int(args.model[-1]) * 4
        model = models.ResnetV2(num_units=models.NUM_UNITS[args.model],
                                num_outputs=21843,
                                filters_factor=filters_factor,
                                name="resnet",
                                trainable=True,
                                dtype=tf.float32)

        model.build((None, None, None, 3))
        logger.info('Loading weights...')
        model.load_weights(bit_model_file)
        logger.info('Weights loaded into model!')

        # Replace the pretrained head with a zero-initialised one sized to
        # the target dataset.
        model._head = tf.keras.layers.Dense(units=dataset_info['num_classes'],
                                            use_bias=True,
                                            kernel_initializer="zeros",
                                            trainable=True,
                                            name="head/dense")

        lr_supports = bit_hyperrule.get_schedule(dataset_info['num_examples'])

        schedule_length = lr_supports[-1]
        # NOTE: Let's not do that unless verified necessary and we do the same
        # across all three codebases.
        # schedule_length = schedule_length * 512 / args.batch

        # No learning rate is set here; the BiTLRSched callback passed to
        # fit() receives the base learning rate instead.
        optimizer = tf.keras.optimizers.SGD(momentum=0.9)
        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

        model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

    logger.info('Fine-tuning the model...')
    # Without --eval_every the whole schedule is run as one Keras epoch.
    steps_per_epoch = args.eval_every or schedule_length
    history = model.fit(
        data_train,
        steps_per_epoch=steps_per_epoch,
        epochs=schedule_length // steps_per_epoch,
        validation_data=data_test,  # here we are only using
        # this data to evaluate our performance
        callbacks=[BiTLRSched(args.base_lr, dataset_info['num_examples'])],
    )

    # Use steps_per_epoch rather than args.eval_every: the latter may be
    # None/0, which would raise a TypeError (or print 0 for every entry).
    # NOTE(review): the label is the step at which the epoch *started*; Keras
    # validates at epoch end, so (epoch + 1) may be what is intended — confirm.
    for epoch, accu in enumerate(history.history['val_accuracy']):
        logger.info(f'Step: {epoch * steps_per_epoch}, '
                    f'Test accuracy: {accu:0.3f}')
コード例 #3
0
ファイル: train.py プロジェクト: cicorias/big_transfer
def main(args):
    """Fine-tune a BiT model with PyTorch, with resume and periodic eval.

    Loads the pretrained ``{args.model}.npz`` weights, optionally resumes from
    ``<logdir>/<name>/bit.pth.tar``, then trains with SGD using the learning
    rate returned by ``bit_hyperrule.get_lr`` until that function returns
    ``None`` or the run is interrupted. Gradients are accumulated over
    ``args.batch_split`` micro-batches per optimizer step.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``logdir``, ``name``, ``base_lr``, ``batch_split``,
            ``eval_every`` and ``save``.
    """
    logger = bit_common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Going to train on {device}")

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info(f"Loading model from {args.model}.npz")
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(np.load(f"{args.model}.npz"))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)
    # Move the parameters to the training device *before* restoring the
    # optimizer: Optimizer.load_state_dict casts the restored state to the
    # device the parameters live on at that moment. Restoring first and moving
    # afterwards leaves the momentum buffers on the CPU and breaks the first
    # optim.step() of a resumed GPU run.
    model = model.to(device)

    step = 0

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    # Resume fine-tuning if we find a saved model.
    # The checkpoint is deserialized on the CPU and copied into the
    # already-placed parameters, so it never occupies extra GPU memory.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info(f"Model will be saved in '{savename}'")
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(f"Found saved model to resume from at '{savename}'")

        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info(f"Resumed at step {step}")
    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()
    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            with chrono.measure("fprop"):
                logits = model(x)
                if mixup > 0.0:
                    c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
                else:
                    c = cri(logits, y)
                c_num = c.item()  # Also ensures a sync point.

            # Accumulate grads; scaling keeps the summed gradient equal to the
            # mean over the full (split) batch.
            with chrono.measure("grads"):
                (c / args.batch_split).backward()
                accum_steps += 1

            accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
            logger.info(
                f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
            logger.flush()

            # Update params once all micro-batches of this step are in.
            if accum_steps == args.batch_split:
                with chrono.measure("update"):
                    optim.step()
                    optim.zero_grad()
                step += 1
                accum_steps = 0
                # Sample new mixup ratio for next batch
                mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

                # Run evaluation and save the model.
                if args.eval_every and step % args.eval_every == 0:
                    run_eval(model, valid_loader, device, chrono, logger, step)
                    if args.save:
                        torch.save(
                            {
                                "step": step,
                                "model": model.state_dict(),
                                "optim": optim.state_dict(),
                            }, savename)

            end = time.time()

        # Final eval at end of training.
        run_eval(model, valid_loader, device, chrono, logger, step='end')

    logger.info(f"Timings:\n{chrono}")
コード例 #4
0
def main(args):
    """Fine-tune a BiT model with PyTorch, optionally from custom weights.

    Loads ``{args.model}.npz``, optionally overlays the backbone weights found
    at ``args.weights_path`` (classifier head excluded), resumes from
    ``<logdir>/<name>/bit.pth.tar`` when present, and trains with SGD using
    the learning rate returned by ``bit_hyperrule.get_lr`` until it returns
    ``None`` or the run is interrupted. A checkpoint is written every 50
    optimizer steps.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``base_lr``, ``weights_path``, ``logdir``, ``name`` and
            ``batch_split``.
    """
    logger = bit_common.setup_logger(args)

    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info("Going to train on {}".format(device))

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)

    logger.info("Loading model from {}.npz".format(args.model))
    model = models.KNOWN_MODELS[args.model](head_size=len(valid_set.classes),
                                            zero_head=True)
    model.load_from(np.load("{}.npz".format(args.model)))

    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    step = 0

    # If pretrained weights are specified
    if args.weights_path:
        logger.info("Loading weights from {}".format(args.weights_path))
        checkpoint = torch.load(args.weights_path, map_location="cpu")
        # New task might have different classes; remove the pretrained
        # classifier weights. pop() tolerates checkpoints that lack them.
        for head_key in ('module.head.conv.weight', 'module.head.conv.bias'):
            checkpoint['model'].pop(head_key, None)
        model.load_state_dict(checkpoint["model"], strict=False)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info("Model will be saved in '{}'".format(savename))
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(
            "Found saved model to resume from at '{}'".format(savename))

        step = checkpoint["step"]
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info("Resumed at step {}".format(step))

    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    # Send to GPU
    optimizer_to(optim, device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)

    logger.info("Starting training!")
    chrono = lb.Chrono()

    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    with lb.Uninterrupt() as u:
        for x, y in recycle(train_loader):
            # measure data loading time, which is spent in the `for` statement.
            chrono._done("load", time.time() - end)

            if u.interrupted:
                break

            # Schedule sending to GPU(s)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Update learning-rate, including stop training if over.
            lr = bit_hyperrule.get_lr(step, len(train_set), args.base_lr)
            if lr is None:
                break
            for param_group in optim.param_groups:
                param_group["lr"] = lr

            if mixup > 0.0:
                x, y_a, y_b = mixup_data(x, y, mixup_l)

            # compute output
            logits = model(x)
            if mixup > 0.0:
                c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
            else:
                c = cri(logits, y)
            c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

            # Accumulate grads. The loss is divided by batch_split, so the
            # optimizer must only step after batch_split micro-batches;
            # stepping every micro-batch would shrink every update by that
            # factor.
            (c / args.batch_split).backward()
            accum_steps += 1

            accstep = " ({}/{})".format(
                accum_steps, args.batch_split) if args.batch_split > 1 else ""
            logger.info("[step {}{}]: loss={:.5f} (lr={:.1e})".format(
                step, accstep, c_num, lr))
            logger.flush()

            # Update params
            if accum_steps == args.batch_split:
                optim.step()
                optim.zero_grad()
                step += 1
                accum_steps = 0

                # Sample new mixup ratio for next batch
                mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

                if step % 50 == 0:
                    torch.save(
                        {
                            "step": step,
                            "model": model.state_dict(),
                            "optim": optim.state_dict(),
                        }, savename)

            # Reset after checkpointing so the save is not booked as
            # data-loading time on the next iteration.
            end = time.time()

        # Final eval at end of training.
        run_eval(model, valid_loader, device, chrono, logger, step='end')

    logger.info("Timings:\n{}".format(chrono))
コード例 #5
0
def main(args):
    """Fine-tune a pretrained EfficientNet for a fixed number of epochs.

    Trains a 5-class EfficientNet with SGD under a per-epoch OneCycleLR
    schedule, evaluates after every epoch and writes a checkpoint to
    ``<logdir>/<name>/bit.pth.tar`` whenever validation accuracy improves.
    Training resumes from that checkpoint when it exists.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``logdir``, ``name`` and ``batch_split``.
    """
    best_acc = -1

    logger = bit_common.setup_logger(args)
    # NOTE(review): cp/cn are never used below — confirm whether this call
    # can be dropped.
    cp, cn = smooth_BCE(eps=0.1)
    # Lets cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Going to train on {device}")

    classes = 5

    train_set, valid_set, train_loader, valid_loader = mktrainval(args, logger)
    logger.info(f"Loading model from {args.model}.npz")

    model = EfficientNet.from_pretrained(args.model, num_classes=classes)
    logger.info("Moving model onto all GPUs")
    model = torch.nn.DataParallel(model)

    # Optionally resume from a checkpoint.
    # Load it to CPU first as we'll move the model to GPU later.
    # This way, we save a little bit of GPU memory when loading.
    start_epoch = 0

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

    # Resume fine-tuning if we find a saved model.
    savename = pjoin(args.logdir, args.name, "bit.pth.tar")
    try:
        logger.info(f"Model will be saved in '{savename}'")
        checkpoint = torch.load(savename, map_location="cpu")
        logger.info(f"Found saved model to resume from at '{savename}'")

        # NOTE(review): the stored epoch is the one that had just finished
        # when the checkpoint was written, so it is repeated on resume, and
        # the scheduler below restarts from its first step — confirm intent.
        start_epoch = checkpoint["epoch"]
        # Carry the best score over, otherwise the first evaluation after a
        # restart always overwrites the stored best checkpoint.
        best_acc = checkpoint.get("val_acc", best_acc)
        model.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
        logger.info(f"Resumed at epoch {start_epoch}")
    except FileNotFoundError:
        logger.info("Fine-tuning from BiT")

    model = model.to(device)
    optim.zero_grad()

    model.train()
    mixup = bit_hyperrule.get_mixup(len(train_set))
    cri = torch.nn.CrossEntropyLoss().to(device)
    logger.info("Starting training!")
    chrono = lb.Chrono()
    accum_steps = 0
    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
    end = time.time()

    epoches = 10
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optim,
                                                    max_lr=0.01,
                                                    steps_per_epoch=1,
                                                    epochs=epoches)

    with lb.Uninterrupt() as u:
        for epoch in range(start_epoch, epoches):

            pbar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader))

            # The schedule advances once per epoch, before that epoch's
            # batches are processed.
            scheduler.step()
            all_top1, all_top5 = [], []
            # Defined up front so the checkpoint dict below is valid even if
            # the loader yields no batch.
            train_acc = float("nan")
            for param_group in optim.param_groups:
                lr = param_group["lr"]
            for batch_id, (x, y) in pbar:
                # measure data loading time, which is spent in the `for` statement.
                chrono._done("load", time.time() - end)

                if u.interrupted:
                    break

                # Schedule sending to GPU(s)
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                if mixup > 0.0:
                    x, y_a, y_b = mixup_data(x, y, mixup_l)

                # compute output
                with chrono.measure("fprop"):
                    logits = model(x)
                    top1, top5 = topk(logits, y, ks=(1, 5))
                    all_top1.extend(top1.cpu())
                    all_top5.extend(top5.cpu())
                    if mixup > 0.0:
                        c = mixup_criterion(cri, logits, y_a, y_b, mixup_l)
                    else:
                        c = cri(logits, y)
                train_loss = c.item()
                train_acc = np.mean(all_top1) * 100.0
                # Accumulate grads
                with chrono.measure("grads"):
                    (c / args.batch_split).backward()
                    accum_steps += 1
                accstep = f"({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
                s = f"epoch={epoch} batch {batch_id}{accstep}: loss={train_loss:.5f} train_acc={train_acc:.2f} lr={lr:.1e}"
                pbar.set_description(s)
                logger.flush()

                # Update params only once batch_split micro-batches have been
                # accumulated; the loss above is already scaled for that.
                if accum_steps == args.batch_split:
                    with chrono.measure("update"):
                        optim.step()
                        optim.zero_grad()
                    accum_steps = 0
                    # Sample new mixup ratio for next batch
                    mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

                # Restart the data-loading stopwatch for the next batch.
                end = time.time()

            # An interrupt must end the whole run, not just this epoch;
            # otherwise every remaining epoch would still run an evaluation.
            if u.interrupted:
                break

            # Run evaluation and save the model.
            val_loss, val_acc = run_eval(model, valid_loader, device, chrono,
                                         logger, epoch)

            best = val_acc > best_acc
            if best:
                best_acc = val_acc
                torch.save(
                    {
                        "epoch": epoch,
                        "val_loss": val_loss,
                        "val_acc": val_acc,
                        "train_acc": train_acc,
                        "model": model.state_dict(),
                        "optim": optim.state_dict(),
                    }, savename)
            end = time.time()

    logger.info(f"Timings:\n{chrono}")
コード例 #6
0
ファイル: train.py プロジェクト: AaronPHD/big_transfer
def main(args):
    """Fine-tune a BiT ResNet with JAX/Flax and log test accuracy.

    Reads the pretrained ``{args.model}.npz`` weights, builds the train and
    test input pipelines, transfers the weights into a freshly initialised
    model, and runs a momentum-SGD loop replicated with ``jax.pmap``.
    Accuracy on the test split is logged every ``args.eval_every`` steps and
    after the final step.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``bit_pretrained_dir``, ``dataset``, ``examples_per_class``,
            ``examples_per_class_seed``, ``batch``, ``batch_eval``,
            ``tfds_manual_dir``, ``base_lr`` and ``eval_every``.

    Raises:
        FileNotFoundError: If the pretrained weight file is missing.
    """
    logger = bit_common.setup_logger(args)
    logger.info(f'Available devices: {jax.devices()}')

    model = models.KNOWN_MODELS[args.model]

    # Read the pretrained BiT weights into a plain dict.
    weights_path = os.path.join(args.bit_pretrained_dir, f'{args.model}.npz')
    if not os.path.exists(weights_path):
        raise FileNotFoundError(
            f'Model file is not found in "{args.bit_pretrained_dir}" directory.'
        )
    with open(weights_path, 'rb') as weights_file:
        archive = np.load(weights_file)
        # Materialise every array while the file is still open.
        params_tf = {key: archive[key] for key in archive.keys()}

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(
        args.dataset)

    # Input pipelines: the train split has no repeat limit (repeats=None),
    # the test split is read once.
    dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train',
                                                   args.examples_per_class)
    num_examples = dataset_info['num_examples']
    num_classes = dataset_info['num_classes']

    data_train = input_pipeline.get_data(
        dataset=args.dataset,
        mode='train',
        repeats=None,
        batch_size=args.batch,
        resize_size=resize_size,
        crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(num_examples),
        num_devices=jax.local_device_count(),
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_train)

    data_test = input_pipeline.get_data(
        dataset=args.dataset,
        mode='test',
        repeats=1,
        batch_size=args.batch_eval,
        resize_size=resize_size,
        crop_size=crop_size,
        examples_per_class=None,
        examples_per_class_seed=0,
        mixup_alpha=None,
        num_devices=jax.local_device_count(),
        tfds_manual_dir=args.tfds_manual_dir)
    logger.info(data_test)

    # Build the architecture with a head sized for the target dataset.
    ResNet = model.partial(num_classes=num_classes)
    init_shapes = [([1, crop_size, crop_size, 3], jnp.float32)]
    _, params = ResNet.init_by_shape(jax.random.PRNGKey(0), init_shapes)
    resnet_fn = ResNet.call

    # Forward pass replicated over all devices (used for evaluation).
    resnet_fn_repl = jax.pmap(ResNet.call)

    def cross_entropy_loss(*, logits, labels):
        """Mean over the batch of the label-weighted negative log-softmax."""
        log_probs = jax.nn.log_softmax(logits)
        per_example = jnp.sum(log_probs * labels, axis=1)
        return -jnp.mean(per_example)

    def loss_fn(params, images, labels):
        """Cross-entropy of the model's logits for one batch."""
        return cross_entropy_loss(logits=resnet_fn(params, images),
                                  labels=labels)

    def _update(opt, lr, batch):
        """One optimizer step using gradients averaged over the 'batch' axis."""
        _, grads = jax.value_and_grad(loss_fn)(opt.target, batch['image'],
                                               batch['label'])
        grads = jax.tree_map(
            lambda leaf: jax.lax.pmean(leaf, axis_name='batch'), grads)
        return opt.apply_gradient(grads, learning_rate=lr)

    # Update step, replicated over all devices.
    update_fn = jax.pmap(_update, axis_name='batch')

    # In-place update of the randomly initialised weights by the BiT weights.
    tf2jax.transform_params(params, params_tf, num_classes=num_classes)

    # Create the optimizer and replicate it over all devices.
    opt = optim.Momentum(beta=0.9).create(params)
    opt_repl = flax_utils.replicate(opt)

    # Drop the references to the objects that are not needed anymore.
    del opt, params

    total_steps = bit_hyperrule.get_schedule(num_examples)[-1]

    # Training loop; `step` is 1-based while the schedule is queried 0-based.
    batches = data_train.as_numpy_iterator()
    for step, batch in zip(range(1, total_steps + 1), batches):
        lr = bit_hyperrule.get_lr(step - 1, num_examples, args.base_lr)
        opt_repl = update_fn(opt_repl, flax_utils.replicate(lr), batch)

        is_last_step = step == total_steps
        is_periodic_eval = args.eval_every and step % args.eval_every == 0
        if not (is_periodic_eval or is_last_step):
            continue

        # Evaluation: collect one correct/incorrect flag per test example.
        # The class index is taken along axis 2 — presumably arrays are laid
        # out as (device, per-device batch, class); TODO confirm.
        hits = []
        for test_batch in data_test.as_numpy_iterator():
            outputs = resnet_fn_repl(opt_repl.target, test_batch['image'])
            predicted = np.argmax(outputs, axis=2)
            expected = np.argmax(test_batch['label'], axis=2)
            hits.extend((predicted == expected).ravel())
        accuracy_test = np.mean(hits)

        logger.info(f'Step: {step}, '
                    f'learning rate: {lr:.07f}, '
                    f'Test accuracy: {accuracy_test:0.3f}')
コード例 #7
0
def main(args):
  """Train a classifier, optionally with inpainting-based counterfactuals.

  When a batch arrives as a ``Sample`` carrying bounding boxes, the boxed
  regions are masked, filled in by the selected inpainter and appended to the
  batch with negated labels (``-y - 1``); those rows are trained with a
  "not this class" loss. Otherwise ordinary cross-entropy is used. Supports
  resume from ``<logdir>/<name>/model.pt``, gradient accumulation over
  ``args.batch_split`` micro-batches and optional apex ``amp`` mixed
  precision.

  Args:
    args: Parsed command-line namespace. Fields read here: ``test_run``,
      ``batch``, ``batch_split``, ``workers``, ``inpaint``, ``model``,
      ``finetune``, ``logdir``, ``name``, ``fp16``, ``mixup``, ``base_lr``,
      ``eval_every``, ``save`` and ``dataset``.

  Raises:
    NotImplementedError: If ``args.inpaint`` names an unknown inpainter.
  """
  logger = bit_common.setup_logger(args)
  if args.test_run:
    args.batch = 8
    args.batch_split = 1
    args.workers = 1

  logger.info("Args: " + str(args))

  # Fix seed
  # torch.manual_seed(args.seed)
  # torch.backends.cudnn.deterministic = True
  # torch.backends.cudnn.benchmark = False
  # np.random.seed(args.seed)
  # random.seed(args.seed)

  # Speed up (the attribute was misspelled `banchmark`, which only created
  # an unused attribute and never enabled cuDNN autotuning).
  torch.backends.cudnn.benchmark = True

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  logger.info(f"Going to train on {device}")

  n_train, n_classes, train_loader, valid_loader = mktrainval(args, logger)

  if args.inpaint != 'none':
    if args.inpaint == 'mean':
      inpaint_model = (lambda x, mask: x*mask)
    elif args.inpaint == 'random':
      inpaint_model = RandomColorWithNoiseInpainter((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    elif args.inpaint == 'local':
      # NOTE(review): the original call read `window=` with no value, which
      # is a syntax error. 15 is a placeholder — confirm the intended size.
      inpaint_model = LocalMeanInpainter(window=15)
    elif args.inpaint == 'cagan':
      inpaint_model = CAInpainter(
        valid_loader.batch_size, checkpoint_dir='./inpainting_models/release_imagenet_256/')
    else:
      raise NotImplementedError(f"Unkown inpaint {args.inpaint}")


  logger.info(f"Training {args.model}")
  is_bit_model = args.model in models.KNOWN_MODELS
  if is_bit_model:
    model = models.KNOWN_MODELS[args.model](head_size=n_classes, zero_head=True)
  else: # from torchvision
    model = getattr(torchvision.models, args.model)(pretrained=args.finetune)

  # Resume fine-tuning if we find a saved model.
  step = 0

  # Optionally resume from a checkpoint.
  # Load it to CPU first as we'll move the model to GPU later.
  # This way, we save a little bit of GPU memory when loading.
  savename = pjoin(args.logdir, args.name, "model.pt")
  try:
    logger.info(f"Model will be saved in '{savename}'")
    checkpoint = torch.load(savename, map_location="cpu")
    logger.info(f"Found saved model to resume from at '{savename}'")

    step = checkpoint["step"]
    model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    # Note: no weight-decay!
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
    optim.load_state_dict(checkpoint["optim"])
    logger.info(f"Resumed at step {step}")
  except FileNotFoundError:
    # Only the BiT models are loaded from an .npz here; torchvision models
    # have no `load_from` method and were already built with
    # `pretrained=args.finetune` above.
    if args.finetune and is_bit_model:
      logger.info("Fine-tuning from BiT")
      model.load_from(np.load(f"models/{args.model}.npz"))

    model = model.to(device)
    optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

  if args.fp16:
    model, optim = amp.initialize(model, optim, opt_level="O1")

  logger.info("Moving model onto all GPUs")
  model = torch.nn.DataParallel(model)

  optim.zero_grad()

  model.train()
  mixup = 0
  if args.mixup:
    mixup = bit_hyperrule.get_mixup(n_train)

  cri = torch.nn.CrossEntropyLoss().to(device)
  def counterfact_cri(logit, y):
    """Cross-entropy for y >= 0 plus a "not class k" loss for y < 0.

    A negative label ``y`` encodes class ``k = -(y + 1)``; for those rows the
    loss is ``-log(1 - softmax(logit)[k])``. The result is averaged over all
    rows of the batch.
    """
    if torch.all(y >= 0):
      return F.cross_entropy(logit, y, reduction='mean')

    loss1 = F.cross_entropy(logit[y >= 0], y[y >= 0], reduction='sum')

    cf_logit, cf_y = logit[y < 0], -(y[y < 0] + 1)

    # Implement my own logsumexp trick
    m, _ = torch.max(cf_logit, dim=1, keepdim=True)
    exp_logit = torch.exp(cf_logit - m)
    sum_exp_logit = torch.sum(exp_logit, dim=1)

    eps = 1e-20
    # Numerator: probability mass of every class except the negated one.
    num = (sum_exp_logit - exp_logit[torch.arange(exp_logit.shape[0]), cf_y])
    num = torch.log(num + eps)
    denon = torch.log(sum_exp_logit + eps)

    # Negative log probability
    loss2 = -(num - denon).sum()
    return (loss1 + loss2) / y.shape[0]

  logger.info("Starting training!")
  chrono = lb.Chrono()
  accum_steps = 0
  mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1
  end = time.time()

  with lb.Uninterrupt() as u:
    for x, y in recycle(train_loader):
      # measure data loading time, which is spent in the `for` statement.
      chrono._done("load", time.time() - end)

      if u.interrupted:
        break

      # Handle inpainting
      if not isinstance(x, Sample) or x.bbox == [None] * len(x.bbox):
        criteron = cri
      else:
        criteron = counterfact_cri

        bboxes = x.bbox
        x = x.img

        # Mask is 1 everywhere except inside the bounding boxes.
        mask = x.new_ones(x.shape[0], 1, *x.shape[2:])
        for i, bbox in enumerate(bboxes):
          for coord_x, coord_y, w, h in zip(bbox.xs, bbox.ys, bbox.ws, bbox.hs):
            mask[i, 0, coord_y:(coord_y + h), coord_x:(coord_x + w)] = 0.

        impute_x = inpaint_model(x, mask)
        impute_y = (-y - 1)

        x = torch.cat([x, impute_x], dim=0)
        # label -1 as negative of class 0, -2 as negative of class 1 etc...
        y = torch.cat([y, impute_y], dim=0)

      # Schedule sending to GPU(s)
      x = x.to(device, non_blocking=True)
      y = y.to(device, non_blocking=True)

      # Update learning-rate, including stop training if over.
      lr = bit_hyperrule.get_lr(step, n_train, args.base_lr)
      if lr is None:
        break
      for param_group in optim.param_groups:
        param_group["lr"] = lr

      if mixup > 0.0:
        x, y_a, y_b = mixup_data(x, y, mixup_l)

      # compute output
      with chrono.measure("fprop"):
        logits = model(x)
        if mixup > 0.0:
          c = mixup_criterion(criteron, logits, y_a, y_b, mixup_l)
        else:
          c = criteron(logits, y)
        c_num = float(c.data.cpu().numpy())  # Also ensures a sync point.

      # Accumulate grads
      with chrono.measure("grads"):
        loss = (c / args.batch_split)
        if args.fp16:
          with amp.scale_loss(loss, optim) as scaled_loss:
            scaled_loss.backward()
        else:
          loss.backward()
        accum_steps += 1

      accstep = f" ({accum_steps}/{args.batch_split})" if args.batch_split > 1 else ""
      logger.info(f"[step {step}{accstep}]: loss={c_num:.5f} (lr={lr:.1e})")  # pylint: disable=logging-format-interpolation
      logger.flush()

      # Update params
      if accum_steps == args.batch_split:
        with chrono.measure("update"):
          optim.step()
          optim.zero_grad()
        step += 1
        accum_steps = 0
        # Sample new mixup ratio for next batch
        mixup_l = np.random.beta(mixup, mixup) if mixup > 0 else 1

        # Run evaluation and save the model.
        if args.eval_every and step % args.eval_every == 0:
          run_eval(model, valid_loader, device, chrono, logger, step)
          if args.save:
            # Save the unwrapped module so the checkpoint loads into a model
            # that is not (yet) wrapped in DataParallel, as done on resume.
            torch.save({
                "step": step,
                "model": model.module.state_dict(),
                "optim": optim.state_dict(),
            }, savename)

      end = time.time()

    # Save model!!
    if args.save:
      torch.save({
        "step": step,
        "model": model.module.state_dict(),
        "optim": optim.state_dict(),
      }, savename)

    # Write the hyperparameters through a context manager so the file is
    # flushed and closed deterministically.
    hyperparams_path = pjoin(args.logdir, args.name, 'hyperparams.json')
    with open(hyperparams_path, 'w') as hyperparams_file:
      json.dump({
        'model': args.model,
        'head_size': n_classes,
        'inpaint': args.inpaint,
        'dataset': args.dataset,
      }, hyperparams_file)

    # Final eval at end of training.
    run_eval(model, valid_loader, device, chrono, logger, step='end')


  logger.info(f"Timings:\n{chrono}")