def get_dataset(name,
                folder,
                split,
                num_fold=10,
                fold_id=0,
                data_aug=False,
                whiten=False,
                div255=False):
    """Gets CIFAR datasets.

    Args:
        name: "cifar-10" or "cifar-100".
        folder: Dataset folder.
        split: "train", "traintrain", "trainval", or "test".
        num_fold: Number of folds; forwarded to the dataset constructor.
        fold_id: Fold index; forwarded to the dataset constructor.
        data_aug: Forwarded to the dataset constructor.
        whiten: Forwarded to the dataset constructor.
        div255: Forwarded to the dataset constructor.

    Returns:
        dp: Dataset object.

    Raises:
        ValueError: If `name` is not one of the supported dataset names.
    """
    if name == "cifar-10":
        dataset_cls = CIFAR10Dataset
    elif name == "cifar-100":
        dataset_cls = CIFAR100Dataset
    else:
        # Previously this formatted an undefined name (`dataset`), which turned
        # the intended error into a NameError. ValueError subclasses Exception,
        # so callers catching Exception still work.
        raise ValueError("Unknown dataset {}".format(name))
    dp = dataset_cls(folder,
                     split,
                     num_fold=num_fold,
                     fold_id=fold_id,
                     data_aug=data_aug,
                     whiten=whiten,
                     div255=div255)
    return dp
Exemple #2
0
    return bpd


if __name__ == "__main__":
    opt = parser.parse_args()
    print(opt)

    print("loading dataset")
    if opt.dataset == "imagenet32":
        train_dataset = Imagenet32Dataset(train=not opt.train_on_val,
                                          max_size=1 if opt.debug else -1)
        val_dataset = Imagenet32Dataset(train=0,
                                        max_size=1 if opt.debug else -1)
    else:
        assert opt.dataset == "cifar10"
        train_dataset = CIFAR10Dataset(train=not opt.train_on_val,
                                       max_size=1 if opt.debug else -1)
        val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)

    print("creating dataloaders")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )

    print("Len train : {}, val : {}".format(len(train_dataloader),
    config.save_dir = os.path.join(args.work_dir, config.save_dir)
    config.log_dir = os.path.join(args.work_dir, config.log_dir)
    config.resume = args.resume

    # set device
    device = torch.device(config.gpu if torch.cuda.is_available() else 'cpu')

    # get model
    model = get_model(config)
    model.to(device)

    # get data
    df = pd.read_csv(config.df_path)
    train_df = df[df['fold'] != 1]
    val_df = df[df['fold'] == 1]
    train_set = CIFAR10Dataset(train_df, config.img_dir, phase='train')
    val_set = CIFAR10Dataset(val_df, config.img_dir, phase='val')
    train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True, num_workers=config.workers)
    val_loader = DataLoader(val_set, batch_size=config.batch_size, shuffle=False, num_workers=config.workers)

    # get training stuff
    criterion = CrossEntropyLoss().to(device)
    optimizer = SGD(model.parameters(), lr=config.lr, weight_decay=config.weight_decay,
                    momentum=config.momentum, nesterov=config.nesterov)
    scheduler = MultiStepLR(optimizer, gamma=config.gamma, milestones=config.milestones)
    start_epoch = 1
    best_err = 1.1

    # optionally resume from a checkpoint
    if config.resume:
        if os.path.isfile(config.resume):
def _train_one_step(build_model, scope_name, wname, batch, learn_rate):
    """Builds a model in a fresh graph, trains it one step, reads one variable.

    Args:
        build_model: Zero-argument callable that constructs the model. It is
            called inside `tf.variable_scope(scope_name)`, and the result must
            provide `assign_lr(sess, lr)` and `train_step(sess, img, label)`.
        scope_name: Name of the variable scope the model is built under.
        wname: Variable name relative to `scope_name`.
        batch: Mapping with "img" and "label" entries fed to `train_step`.
        learn_rate: Learning rate assigned before the training step.

    Returns:
        Tuple `(w_before, w_after)` with the variable's fetched values before
        and after the single training step.
    """
    with tf.Graph().as_default():
        # The graph-level seed is consulted when a random op is created, so it
        # must be set before the model (and its initializers) is built;
        # setting it after construction has no effect on those initializers.
        tf.set_random_seed(1234)
        with tf.variable_scope(scope_name):
            model = build_model()
        # Context-managed session so it is closed once values are fetched.
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            model.assign_lr(sess, learn_rate)
            with tf.variable_scope(scope_name, reuse=True):
                w_var = tf.get_variable(wname)
            w_before = sess.run(w_var)
            model.train_step(sess, batch["img"], batch["label"])
            w_after = sess.run(w_var)
    return w_before, w_after


def test_multi_pass():
    """Checks that the multi-pass, multi-tower and plain CNN models agree.

    For every variable in `wlist`, each of MultiPassMultiTowerModel,
    MultiTowerModel and CNNModel is built in its own graph with the same seed
    and trained for one step on the same CIFAR-10 "valid" batch. The test then
    asserts that the initial weights match, and that the weight-change
    signatures (`get_diff_signature`) match, to `decimal_tol` decimals.
    """
    import cifar_exp_config as cifar_conf
    from data import CIFAR10Dataset
    from utils import BatchIterator
    import os
    if os.path.exists("/ais/gobi4"):
        folder = "/ais/gobi4/mren/data/cifar-10"
    else:
        folder = "/home/mren/data/cifar-10"
    data = CIFAR10Dataset(folder=folder, split="valid")
    config = cifar_conf.BaselineConfig()
    b = BatchIterator(data.get_size(),
                      batch_size=8,
                      shuffle=False,
                      cycle=False,
                      get_fn=data.get_batch_idx)

    # Testing the batch iterator: after reset() the first batch is replayed.
    b1 = b.next()
    b.reset()
    b2 = b.next()
    np.testing.assert_almost_equal(b1["img"], b2["img"])
    b.reset()
    config.pool_fn = ["avg_pool", "avg_pool", "avg_pool"]

    num_rep = 4
    num_pas = 2
    learn_rate = 1.0
    decimal_tol = 5
    num_elem_dbg = 3
    wlist = [
        "mlp/layer_1/w", "mlp/layer_0/w", "cnn/layer_2/w", "cnn/layer_1/w",
        "cnn/layer_0/w", "mlp/layer_1/b", "mlp/layer_0/b", "cnn/layer_2/b",
        "cnn/layer_1/b", "cnn/layer_0/b"
    ]

    for wname in wlist:
        with log.verbose_level(2):
            # All three models train on the same (first) batch; rewind so the
            # next variable's iteration sees that same batch again.
            batch = b.next()
            b.reset()

            # Multi-pass multi-tower model.
            w1, w1d = _train_one_step(
                lambda: MultiPassMultiTowerModel(config,
                                                 num_replica=num_rep,
                                                 num_passes=num_pas),
                "Model", wname, batch, learn_rate)
            # Regular multi-tower model.
            w2, w2d = _train_one_step(
                lambda: MultiTowerModel(config, num_replica=num_rep),
                "Model2", wname, batch, learn_rate)
            # Regular single-tower model.
            w3, w3d = _train_one_step(
                lambda: CNNModel(config),
                "Model3", wname, batch, learn_rate)

        # Kept outside the verbose_level block to avoid its logging.
        ######################################
        # Make sure the weights are the same.
        ######################################
        log.info("Testing {}".format(wname))
        print_w("w1", w1, num_elem_dbg)
        print_w("w2", w2, num_elem_dbg)
        print_w("w3", w3, num_elem_dbg)
        np.testing.assert_almost_equal(w1, w2, decimal=decimal_tol)
        np.testing.assert_almost_equal(w2, w3, decimal=decimal_tol)

        ######################################
        # Make sure the gradients are the same.
        ######################################
        print_w("w1 delta", w1d - w1, num_elem_dbg)
        print_w("w2 delta", w2d - w2, num_elem_dbg)
        print_w("w3 delta", w3d - w3, num_elem_dbg)
        print_w("w1 new", w1d, num_elem_dbg)
        print_w("w2 new", w2d, num_elem_dbg)
        print_w("w3 new", w3d, num_elem_dbg)

        np.testing.assert_almost_equal(get_diff_signature(w1, w1d),
                                       get_diff_signature(w2, w2d),
                                       decimal=decimal_tol)
        np.testing.assert_almost_equal(get_diff_signature(w2, w2d),
                                       get_diff_signature(w3, w3d),
                                       decimal=decimal_tol)
        log.info("Success")
def main(args=None):
    """Parses options, builds data loaders and a Flow++ model, then trains or evaluates.

    Args:
        args: Optional list of command-line tokens. When None, argparse falls
            back to ``sys.argv[1:]``. An explicit list, even an empty one, is
            parsed exactly as given.
    """
    # parse_args(None) already reads sys.argv[1:]. Passing `args` straight
    # through means main([]) parses no options, instead of silently reading
    # the real command line (as the old `if args:` check did).
    opt = parser.parse_args(args)

    print(opt)

    print("loading dataset")
    if opt.dataset == "imagenet32":
        train_dataset = Imagenet32DatasetDiscrete(
            train=not opt.train_on_val,
            max_size=1 if opt.debug else opt.train_size)
        val_dataset = Imagenet32DatasetDiscrete(
            train=0,
            max_size=1 if opt.debug else opt.val_size,
            start_idx=opt.val_start_idx)
    else:
        assert opt.dataset == "cifar10"
        train_dataset = CIFAR10Dataset(train=not opt.train_on_val,
                                       max_size=1 if opt.debug else -1)
        val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)

    print("creating dataloaders")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )

    print("Len train : {}, val : {}".format(len(train_dataloader),
                                            len(val_dataloader)))

    device = torch.device("cuda") if (
        torch.cuda.is_available() and opt.use_cuda) else torch.device("cpu")
    print("Device is {}".format(device))

    print("Loading models on device...")

    # Initialize embedder
    if opt.conditioning == 'unconditional':
        encoder = UnconditionalClassEmbedding()
    elif opt.conditioning == "bert":
        encoder = BERTEncoder()
    else:
        assert opt.conditioning == "one-hot"
        encoder = OneHotClassEmbedding(train_dataset.n_classes)

    generative_model = FlowPlusPlus(
        scales=[(0, 4), (2, 3)],
        in_shape=train_dataset.image_shape,
        mid_channels=opt.n_filters,
        num_blocks=opt.num_blocks,
        num_dequant_blocks=opt.num_dequant_blocks,
        num_components=opt.num_components,
        use_attn=opt.use_attn,
        drop_prob=opt.drop_prob,
        condition_embd_size=encoder.embed_size)

    generative_model = generative_model.to(device)
    encoder = encoder.to(device)
    print("Models loaded on device")

    print("dataloaders loaded")
    # Optimizers
    param_groups = util.get_param_groups(generative_model,
                                         opt.lr_decay,
                                         norm_suffix='weight_g')
    optimizer = torch.optim.Adam(param_groups, lr=opt.lr)
    # warm_up is scaled by batch_size, so the scheduler is presumably stepped
    # by the batch size per iteration; confirm against train().
    warm_up = opt.warm_up * opt.batch_size
    # Linear ramp min(1, s / warm_up). A non-positive warm_up disables the ramp.
    # Without this guard, LambdaLR evaluates the lambda at step 0 during
    # construction, and warm_up == 0 would raise ZeroDivisionError.
    scheduler = lr_scheduler.LambdaLR(
        optimizer,
        lambda s: min(1., s / warm_up) if warm_up > 0 else 1.)

    # create output directory
    os.makedirs(os.path.join(opt.output_dir, "models"), exist_ok=True)
    os.makedirs(os.path.join(opt.output_dir, "tensorboard"), exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(opt.output_dir, "tensorboard"))

    global global_step
    global_step = 0

    # ----------
    #  Training
    # ----------
    if opt.train:
        train(model=generative_model,
              embedder=encoder,
              optimizer=optimizer,
              scheduler=scheduler,
              train_loader=train_dataloader,
              val_loader=val_dataloader,
              opt=opt,
              writer=writer,
              device=device)
    else:
        assert opt.model_checkpoint is not None, 'no model checkpoint specified'
        print("Loading model from state dict...")
        load_model(opt.model_checkpoint, generative_model)
        print("Model loaded.")
        sample_images_full(generative_model,
                           encoder,
                           opt.output_dir,
                           dataloader=val_dataloader,
                           device=device)
        eval(model=generative_model,
             embedder=encoder,
             test_loader=val_dataloader,
             opt=opt,
             writer=writer,
             device=device)
def main(args=None):
    """Parses options, builds data loaders and a conditional PixelCNN++, then trains or evaluates.

    Args:
        args: Optional list of command-line tokens. When None, argparse falls
            back to ``sys.argv[1:]``. An explicit list, even an empty one, is
            parsed exactly as given.
    """
    # parse_args(None) already reads sys.argv[1:]. Passing `args` straight
    # through means main([]) parses no options, instead of silently reading
    # the real command line (as the old `if args:` check did).
    opt = parser.parse_args(args)

    print(opt)

    print("loading dataset")
    if opt.dataset == "imagenet32":
        train_dataset = Imagenet32Dataset(
            train=not opt.train_on_val,
            max_size=1 if opt.debug else opt.train_size)
        val_dataset = Imagenet32Dataset(
            train=0,
            max_size=1 if opt.debug else opt.val_size,
            start_idx=opt.val_start_idx)
    else:
        assert opt.dataset == "cifar10"
        train_dataset = CIFAR10Dataset(train=not opt.train_on_val,
                                       max_size=1 if opt.debug else -1)
        val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)

    print("creating dataloaders")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )

    print("Len train : {}, val : {}".format(len(train_dataloader),
                                            len(val_dataloader)))

    device = torch.device("cuda") if (
        torch.cuda.is_available() and opt.use_cuda) else torch.device("cpu")
    print("Device is {}".format(device))

    print("Loading models on device...")

    # Initialize embedder
    if opt.conditioning == 'unconditional':
        encoder = UnconditionalClassEmbedding()
    elif opt.conditioning == "bert":
        encoder = BERTEncoder()
    else:
        assert opt.conditioning == "one-hot"
        encoder = OneHotClassEmbedding(train_dataset.n_classes)

    # Single-channel images use 3 logistic mixture components, others use 10.
    generative_model = ConditionalPixelCNNpp(
        embd_size=encoder.embed_size,
        img_shape=train_dataset.image_shape,
        nr_resnet=opt.n_resnet,
        nr_filters=opt.n_filters,
        nr_logistic_mix=3 if train_dataset.image_shape[0] == 1 else 10)

    generative_model = generative_model.to(device)
    encoder = encoder.to(device)
    print("Models loaded on device")

    print("dataloaders loaded")
    # Optimizers: Adam with a step decay of the learning rate by opt.lr_decay
    # each time the scheduler is stepped.
    optimizer = torch.optim.Adam(generative_model.parameters(), lr=opt.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=opt.lr_decay)

    # create output directory
    os.makedirs(os.path.join(opt.output_dir, "models"), exist_ok=True)
    os.makedirs(os.path.join(opt.output_dir, "tensorboard"), exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(opt.output_dir, "tensorboard"))

    # ----------
    #  Training
    # ----------
    if opt.train:
        train(model=generative_model,
              embedder=encoder,
              optimizer=optimizer,
              scheduler=scheduler,
              train_loader=train_dataloader,
              val_loader=val_dataloader,
              opt=opt,
              writer=writer,
              device=device)
    else:
        assert opt.model_checkpoint is not None, 'no model checkpoint specified'
        print("Loading model from state dict...")
        load_model(opt.model_checkpoint, generative_model)
        print("Model loaded.")
        sample_images_full(generative_model,
                           encoder,
                           opt.output_dir,
                           dataloader=val_dataloader,
                           device=device)
        eval(model=generative_model,
             embedder=encoder,
             test_loader=val_dataloader,
             opt=opt,
             writer=writer,
             device=device)
Exemple #7
0
print(opt)

# Create the models/, samples/ and tensorboard/ subdirectories of
# opt.output_dir (no error if they already exist).
for _subdir in ("models", "samples", "tensorboard"):
    os.makedirs(os.path.join(opt.output_dir, _subdir), exist_ok=True)
del _subdir  # keep the module namespace free of the loop variable

# NOTE(review): in torch.utils.tensorboard / tensorboardX, `comment` is ignored
# when `log_dir` is given explicitly; confirm which SummaryWriter is imported.
writer = SummaryWriter(
    log_dir=os.path.join(opt.output_dir, "tensorboard"),
    comment='Cifar10',
)

################# load data #####################
print("loading dataset")
if opt.dataset == "imagenet":
    train_dataset = Imagenet32Dataset(train=True, max_size=1 if opt.debug else -1)
    val_dataset = Imagenet32Dataset(train=0, max_size=1 if opt.debug else -1)
elif opt.dataset == "cifar10":
    train_dataset = CIFAR10Dataset(train=True, max_size=1 if opt.debug else -1)
#    train_dataset = dset.CIFAR10(
#        root=opt.dataroot, download=True,
#        transform=transforms.Compose([
#            transforms.Scale(opt.imageSize),
#            transforms.ToTensor(),
#            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#        ]))
    val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)
elif opt.dataset == "coco":
    print("INFO: using coco")
    path2data="/home/ooo/Data/train2017"
    path2json="/home/ooo/Data/annotations_trainval2017/annotations/captions_train2017.json"
    train_dataset = dset.CocoCaptions(
        root=path2data, annFile=path2json,
        transform=transforms.Compose([