Example #1
0
# smooth_label has to lie in [0, 0.5)
assert 0 <= params.smooth_label < 0.5

# every *_reload option is either unset or the path of an existing file
assert all(
    not reload_path or os.path.isfile(reload_path)
    for reload_path in (
        params.ae_reload,
        params.lat_dis_reload,
        params.ptc_dis_reload,
        params.clf_dis_reload,
    )
)

# the evaluation classifier file is mandatory
assert os.path.isfile(params.eval_clf)

# a non-zero lambda_* requires the matching n_* to be positive
assert all(
    coefficient == 0 or count > 0
    for coefficient, count in (
        (params.lambda_lat_dis, params.n_lat_dis),
        (params.lambda_ptc_dis, params.n_ptc_dis),
        (params.lambda_clf_dis, params.n_clf_dis),
    )
)

# initialize experiment / load dataset
logger = initialize_exp(params)
data, attributes, data2, attributes2 = load_images(params)
# index 0 goes to the training sampler together with the second data source;
# index 1 goes to the validation sampler, which gets no second source
train_data = DataSampler(data[0], attributes[0], data2, attributes2, params)
valid_data = DataSampler(data[1], attributes[1], None, None, params)

# build the model; each optional network only exists when its n_* is truthy
ae = AutoEncoder(params).cuda()

lat_dis = None
if params.n_lat_dis:
    lat_dis = LatentDiscriminator(params).cuda()

ptc_dis = None
if params.n_ptc_dis:
    ptc_dis = PatchDiscriminator(params).cuda()

clf_dis = None
if params.n_clf_dis:
    clf_dis = Classifier(params).cuda()

# the evaluation classifier is loaded from disk and switched to eval mode
eval_clf = torch.load(params.eval_clf)
eval_clf = eval_clf.cuda().eval()

# trainer / evaluator
trainer = Trainer(ae, lat_dis, ptc_dis, clf_dis,
                  train_data, params)
evaluator = Evaluator(ae, lat_dis, ptc_dis, clf_dis,
                      eval_clf, valid_data, params)


# NOTE(review): iterates over params.n_epochs epochs; the loop body is not
# present in this excerpt and must be restored before the script can run.
for n_epoch in range(params.n_epochs):
# copy the image / attribute settings stored on the model onto the parameters
params.img_sz = ae.img_sz
params.attr = ae.attr
params.n_attr = ae.n_attr

# load dataset
data, attributes = load_images(params)

# translate the requested split name into its index in `data` / `attributes`;
# fail with an explicit message instead of a NameError on an unknown name
_split_index = {'train': 0, 'val': 1, 'test': 2}
if params.dataset not in _split_index:
    raise ValueError(
        "unknown dataset %r (expected 'train', 'val' or 'test')"
        % (params.dataset,)
    )
data_ix = _split_index[params.dataset]

test_data = DataSampler(data[data_ix], attributes[data_ix], params)


def get_interpolations(ae, images, attributes, params, alphas):
    """
    Reconstruct images / create interpolations.

    Calls `ae.eval()`, checks that `images` and `attributes` have the same
    length and encodes `images` with `ae.encode`.

    NOTE(review): the definition is cut off in this excerpt right after the
    latent code is read. `params`, `alphas` and `bs` are not used in the
    visible part and no return statement is visible -- confirm against the
    full source before relying on this function.
    """
    ae.eval()

    assert len(images) == len(attributes)
    enc_outputs = ae.encode(images)

    # separate latent code and attribute prediction
    # size of the first dimension of the first encoder output
    # (presumably the batch size -- TODO confirm)
    bs = enc_outputs[0].size(0)

    z_all = enc_outputs[-1]  # full latent code
# create logger / load trained model (switched to eval mode)
logger = create_logger(None)
ae = torch.load(params.model_path)
ae = ae.eval()

# restore main parameters: fixed settings first, then the image / attribute
# settings stored on the loaded model
params.debug, params.batch_size = True, 32
params.v_flip = params.h_flip = False
params.img_sz, params.attr, params.n_attr = ae.img_sz, ae.attr, ae.n_attr

# load dataset; the sampler is built from the entries at index 2
data, attributes = load_images(params)
test_data = DataSampler(
    data[2],
    attributes[2],
    params,
)


def get_interpolations_2dim(ae, images, attributes, params):
    """
    Reconstruct images / create interpolations two dimensionally.

    Checks that `images` and `attributes` have the same length, encodes
    `images` with `ae.encode` and prepares two ranges of interpolation values.

    NOTE(review): the definition is cut off in this excerpt after the two
    ranges are computed; no decoding or return statement is visible.
    """
    assert len(images) == len(attributes)
    enc_outputs = ae.encode(images)

    # interpolation values: params.n_interpolations evenly spaced points from
    # 1 - alpha_min_k to alpha_max_k, one range per dimension k
    alphas_1 = np.linspace(1 - params.alpha_min_1, params.alpha_max_1,
                           params.n_interpolations)
    alphas_2 = np.linspace(1 - params.alpha_min_2, params.alpha_max_2,
                           params.n_interpolations)
Example #4
0
# architecture options
assert params.n_skip <= params.n_layers - 1
assert params.deconv_method in ('convtranspose', 'upsampling', 'pixelshuffle')

# smooth_label has to lie in [0, 0.5)
assert 0 <= params.smooth_label < 0.5

# every *_reload option is either unset or the path of an existing file
assert all(
    not checkpoint or os.path.isfile(checkpoint)
    for checkpoint in (
        params.ae_reload,
        params.lat_dis_reload,
        params.ptc_dis_reload,
        params.clf_dis_reload,
    )
)

# the evaluation classifier file is mandatory
assert os.path.isfile(params.eval_clf)

# a non-zero lambda_* requires the matching n_* to be positive
assert all(
    coefficient == 0 or count > 0
    for coefficient, count in (
        (params.lambda_lat_dis, params.n_lat_dis),
        (params.lambda_ptc_dis, params.n_ptc_dis),
        (params.lambda_clf_dis, params.n_clf_dis),
    )
)

# initialize experiment / load dataset
logger = initialize_exp(params)
train_val_test_images, train_val_test_attrs = load_images(params)
# index 0 feeds the training sampler, index 1 the validation sampler
train_data = DataSampler(
    train_val_test_images[0], train_val_test_attrs[0], params)
valid_data = DataSampler(
    train_val_test_images[1], train_val_test_attrs[1], params)

# build the model; each optional network only exists when its n_* is truthy
ae = AutoEncoder(params).cuda()

lat_dis = None
if params.n_lat_dis:
    lat_dis = LatentDiscriminator(params).cuda()

ptc_dis = None
if params.n_ptc_dis:
    ptc_dis = PatchDiscriminator(params).cuda()

clf_dis = None
if params.n_clf_dis:
    clf_dis = Classifier(params).cuda()

# the evaluation classifier is loaded from disk and switched to eval mode
eval_clf = torch.load(params.eval_clf)
eval_clf = eval_clf.cuda().eval()

# trainer / evaluator
trainer = Trainer(ae, lat_dis, ptc_dis, clf_dis,
                  train_data, params)
evaluator = Evaluator(ae, lat_dis, ptc_dis, clf_dis,
                      eval_clf, valid_data, params)
Example #5
0
parser.add_argument(
    "--n_source",
    type=int,
    default=193390,
    help="number of source images",
)
params = parser.parse_args()
params.model_type = "classifier"

# check parameters: attributes, non-blank experiment name, reload path
check_attr(params)
assert params.name.strip()
assert not params.reload or os.path.isfile(params.reload)

# initialize experiment / load dataset
logger = initialize_exp(params)
data, attributes = load_images(params)
# one sampler per entry, built in order: index 0, 1, 2
train_data, valid_data, test_data = (
    DataSampler(data[split], attributes[split], params)
    for split in range(3)
)

# build the model / reload / optimizer
classifier = Classifier(params).cuda()
if params.reload:
    reload_model(
        classifier,
        params.reload,
        ['img_sz', 'img_fm', 'init_fm', 'hid_dim', 'attr', 'n_attr'],
    )
optimizer = get_optimizer(classifier, params.optimizer)


def save_model(name):
    """
    Save the model.

    name: presumably the identifier the model is saved under -- TODO confirm.

    NOTE(review): only the docstring of this function is present in this
    excerpt; the body that performs the save is not visible.
    """
Example #6
0
def interpolate(ae, n_epoch):
    """
    Save a PNG grid of reconstructions / attribute interpolations for `ae`.

    Settings come from `parameters.interpolateParams()`. `params.n_images`
    images starting at `params.offset` are taken from the sampler built on
    the entries at index 2 of the loaded data, processed in chunks of 100.
    For each image the grid holds the original, its reconstruction and
    `params.n_interpolations` decodings with interpolated attribute values.
    The grid is written to `params.output_path + str(n_epoch) + ".png"`.

    ae: auto-encoder providing `img_sz`, `attr`, `n_attr`, `encode`, `decode`
    n_epoch: value appended (through `str`) to the output path

    Raises an Exception when the model does not use exactly one attribute
    with two values.

    NOTE(review): `ae` is used in whatever mode the caller left it in (the
    `.eval()` call was commented out) -- confirm this is intended.
    """

    def get_interpolations(ae, images, attributes, params):
        """
        Reconstruct images / create interpolations
        """
        assert len(images) == len(attributes)
        enc_outputs = ae.encode(images)

        # interpolation values, each turned into a (1 - alpha, alpha) pair
        alphas = np.linspace(1 - params.alpha_min, params.alpha_max,
                             params.n_interpolations)
        alphas = [torch.FloatTensor([1 - alpha, alpha]) for alpha in alphas]

        # original image / reconstructed image / interpolations
        outputs = []
        outputs.append(images)
        outputs.append(ae.decode(enc_outputs, attributes)[-1])
        for alpha in alphas:
            # repeat the same attribute pair for every image of the batch
            alpha = Variable(
                alpha.unsqueeze(0).expand((len(images), 2)).cuda())
            outputs.append(ae.decode(enc_outputs, alpha)[-1])

        # return stacked images
        return torch.cat([x.unsqueeze(1) for x in outputs], 1).data.cpu()

    def get_grid(images, row_wise, plot_size=5):
        """
        Create a grid with all images.
        """
        n_images, n_columns, img_fm, img_sz, _ = images.size()
        if not row_wise:
            images = images.transpose(0, 1).contiguous()
        images = images.view(n_images * n_columns, img_fm, img_sz, img_sz)
        # in-place rescale: x -> (x + 1) / 2
        # (presumably from [-1, 1] to [0, 1] -- TODO confirm)
        images.add_(1).div_(2.0)
        return make_grid(images, nrow=(n_columns if row_wise else n_images))

    params = parameters.interpolateParams()
    params.debug = True
    params.batch_size = 50
    params.v_flip = False
    params.h_flip = False
    params.img_sz = ae.img_sz
    params.attr = ae.attr
    params.n_attr = ae.n_attr
    if not (len(params.attr) == 1 and params.n_attr == 2):
        raise Exception("The model must use a single boolean attribute only.")
    data, attributes = load_images(params)
    test_data = DataSampler(data[2], attributes[2], params)

    # collect the interpolations chunk by chunk
    interpolations = []
    for k in range(0, params.n_images, 100):
        i = params.offset + k
        j = params.offset + min(params.n_images, k + 100)
        batch_images, batch_attributes = test_data.eval_batch(i, j)
        interpolations.append(
            get_interpolations(ae, batch_images, batch_attributes, params))

    # concatenate only once every chunk is collected: doing it inside the
    # loop turns the list into a tensor after the first chunk, which breaks
    # the size check / the next append whenever n_images > 100
    interpolations = torch.cat(interpolations, 0)
    assert interpolations.size() == (params.n_images,
                                     2 + params.n_interpolations, 3,
                                     params.img_sz, params.img_sz)

    # generate the grid / save it to a PNG file
    grid = get_grid(interpolations, params.row_wise, params.plot_size)
    matplotlib.image.imsave(params.output_path + str(n_epoch) + ".png",
                            grid.numpy().transpose((1, 2, 0)))