예제 #1
0
def grid_from_latents(z,
                      dmodel,
                      rows,
                      cols,
                      anchor_images,
                      tight,
                      shoulders,
                      save_path,
                      args=None,
                      batch_size=24,
                      template_dict=None,
                      emb_l=None):
    """Decode latent vectors in batches and save them as an image grid.

    Args:
        z: sequence of latent vectors, decoded in order.
        dmodel: model exposing ``sample_at(z_batch)`` and, when ``emb_l`` is
            given, ``decode_embedded(z_batch, emb_batch)``.
        rows, cols: grid dimensions for ``grid2img`` (``add_shoulders`` may
            change them).
        anchor_images: passed to ``add_shoulders`` when ``shoulders`` is set.
        tight: ``grid2img`` receives ``not tight`` as its fourth argument.
        shoulders: whether to call ``add_shoulders`` before building the grid.
        save_path: filename template expanded by ``emit_filename``.
        args: forwarded to ``emit_filename``.
        batch_size: latents decoded per model call; must be >= 1.
        template_dict: substitution values for ``emit_filename``. The key
            ``"SIZE"`` is written into it (a caller-supplied dict is updated
            in place). Defaults to a fresh empty dict per call.
        emb_l: optional per-latent embeddings aligned with ``z``; when given,
            ``decode_embedded`` is used instead of ``sample_at``.

    Returns:
        None. Writes the image file as a side effect; prints a message and
        returns early when there are no decoded samples (e.g. empty ``z``).

    Raises:
        ValueError: if ``batch_size`` < 1 (the old queue loop never
            terminated in that case).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    # A fresh dict per call: the former `template_dict={}` default was a
    # single object shared between calls and is mutated below.
    if template_dict is None:
        template_dict = {}

    use_embedded = emb_l is not None

    samples = None
    # Walk z by index instead of repeatedly re-slicing a shrinking copy.
    for start in range(0, len(z), batch_size):
        stop = start + batch_size
        cur_z = z[start:stop]
        if use_embedded:
            decoded = dmodel.decode_embedded(cur_z, emb_l[start:stop])
        else:
            decoded = dmodel.sample_at(cur_z)
        if samples is None:
            samples = decoded
        else:
            samples = np.concatenate((samples, decoded), axis=0)

    if shoulders:
        samples, rows, cols = add_shoulders(samples, anchor_images, rows, cols)

    # samples is still None when z was empty; iterating None would raise
    # TypeError rather than reaching the "nothing to save" path.
    candidates = samples if samples is not None else []
    one_sample = next((item for item in candidates if item is not None), None)
    if one_sample is None:
        print("No samples found to save")
        return

    # each sample is 3xsizexsize (channels first), so shape[1] is the size
    template_dict["SIZE"] = one_sample.shape[1]
    final_save_path = emit_filename(save_path, template_dict, args)
    # Write to "_<name>" in the same directory, then rename, so a partially
    # written image never appears under the final name.
    dirname, basename = os.path.split(final_save_path)
    outfile_temp = os.path.join(dirname, '_' + basename)
    print("Saving image file {}".format(final_save_path))
    if dirname != '' and not os.path.exists(dirname):
        os.makedirs(dirname)
    img = grid2img(samples, rows, cols, not tight)
    img.save(outfile_temp)
    os.rename(outfile_temp, final_save_path)
    # Bump the mtime (what `touch` did) without spawning a shell, which broke
    # on paths containing spaces or shell metacharacters.
    os.utime(final_save_path, None)
예제 #2
0
def grid_from_latents(z,
                      dmodel,
                      rows,
                      cols,
                      anchor_images,
                      tight,
                      shoulders,
                      save_path,
                      batch_size=24):
    """Decode latent vectors in batches and save them as an image grid.

    Args:
        z: sequence of latent vectors, decoded in order via
            ``dmodel.sample_at``.
        dmodel: model exposing ``sample_at(z_batch)``.
        rows, cols: grid dimensions for ``grid2img`` (``add_shoulders`` may
            change them).
        anchor_images: passed to ``add_shoulders`` when ``shoulders`` is set.
        tight: ``grid2img`` receives ``not tight`` as its fourth argument.
        shoulders: whether to call ``add_shoulders`` before building the grid.
        save_path: output path passed to ``img.save``.
        batch_size: latents decoded per ``sample_at`` call; must be >= 1.

    Raises:
        ValueError: if ``batch_size`` < 1 (the old queue loop never
            terminated in that case).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))

    samples = None
    # Walk z by index instead of repeatedly re-slicing a shrinking copy.
    for start in range(0, len(z), batch_size):
        decoded = dmodel.sample_at(z[start:start + batch_size])
        if samples is None:
            samples = decoded
        else:
            samples = np.concatenate((samples, decoded), axis=0)

    if shoulders:
        samples, rows, cols = add_shoulders(samples, anchor_images, rows, cols)

    print('Preparing image grid...')
    img = grid2img(samples, rows, cols, not tight)
    img.save(save_path)
예제 #3
0
파일: sample.py 프로젝트: dribnet/discgen
def grid_from_latents(z, dmodel, rows, cols, anchor_images, tight, shoulders, save_path, batch_size=24):
    """Decode latent vectors in batches and save them as an image grid.

    Args:
        z: sequence of latent vectors, decoded in order via
            ``dmodel.sample_at``.
        dmodel: model exposing ``sample_at(z_batch)``.
        rows, cols: grid dimensions for ``grid2img`` (``add_shoulders`` may
            change them).
        anchor_images: passed to ``add_shoulders`` when ``shoulders`` is set.
        tight: ``grid2img`` receives ``not tight`` as its fourth argument.
        shoulders: whether to call ``add_shoulders`` before building the grid.
        save_path: output path passed to ``img.save``.
        batch_size: latents decoded per ``sample_at`` call; must be >= 1.

    Raises:
        ValueError: if ``batch_size`` < 1 (the old queue loop never
            terminated in that case).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))

    samples = None
    # Walk z by index instead of repeatedly re-slicing a shrinking copy.
    for start in range(0, len(z), batch_size):
        decoded = dmodel.sample_at(z[start:start + batch_size])
        if samples is None:
            samples = decoded
        else:
            samples = np.concatenate((samples, decoded), axis=0)

    if shoulders:
        samples, rows, cols = add_shoulders(samples, anchor_images, rows, cols)

    print('Preparing image grid...')
    img = grid2img(samples, rows, cols, not tight)
    img.save(save_path)
예제 #4
0
파일: sample.py 프로젝트: dribnet/discgen
def run_with_args(args, dmodel, cur_anchor_image, cur_save_path, cur_z_step):
    """Build a latent grid from the options in ``args`` and save it as an image.

    The steps run in this order:

    1. Optionally seed the RNGs.
    2. Collect anchor images from the dataset and/or ``cur_anchor_image``.
    3. In passthrough mode, save those anchors directly.
    4. Load the model when ``dmodel`` is None.
    5. Encode the anchors to latents, or read them from JSON.
    6. Apply anchor and global offsets.
    7. Generate the latent grid and decode and save it via ``grid_from_latents``.

    Args:
        args: parsed options namespace; many attributes are read.
        dmodel: an already-loaded model, or None to instantiate
            ``args.model_class`` with ``filename=args.model``.
        cur_anchor_image: passed to ``anchors_from_image``, or None.
        cur_save_path: output filename for the grid image.
        cur_z_step: forwarded to ``anchors_noise_offsets`` and used to scale
            the global offset when ``args.global_ramp`` is set.

    Returns:
        The model used, so callers can reuse it on the next call.

    Note:
        Calls ``sys.exit(0)`` in passthrough and encoder modes.

    Raises:
        ValueError: if anchor images must be encoded and
            ``args.batch_size`` < 1.
    """
    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)

    anchor_images = None
    if args.anchors:
        _, get_anchor_images = lazy_init_fuel_dependencies()
        allowed = None
        prohibited = None
        include_targets = False
        # list(): on Python 3 map() is a one-shot iterator, which silently
        # breaks any repeated iteration or membership test downstream.
        if args.allowed:
            include_targets = True
            allowed = list(map(int, args.allowed.split(",")))
        if args.prohibited:
            include_targets = True
            prohibited = list(map(int, args.prohibited.split(",")))
        anchor_images = get_anchor_images(args.dataset, args.split, args.offset, args.stepsize,
                                          args.numanchors, allowed, prohibited, args.image_size,
                                          args.color_convert, include_targets=include_targets)

    if cur_anchor_image is not None:
        _, _, anchor_images = anchors_from_image(cur_anchor_image, image_size=(args.image_size, args.image_size))
        if args.offset > 0:
            anchor_images = anchor_images[args.offset:]
        # untested
        if args.numanchors is not None:
            anchor_images = anchor_images[:args.numanchors]

    if args.passthrough:
        print('Preparing image grid...')
        img = grid2img(anchor_images, args.rows, args.cols, not args.tight)
        img.save(cur_save_path)
        sys.exit(0)

    if dmodel is None:
        # args.model_class is a dotted path: "package.module.ClassName".
        model_class_parts = args.model_class.split(".")
        model_class_name = model_class_parts[-1]
        model_module_name = ".".join(model_class_parts[:-1])
        print("Loading {} interface from {}".format(model_class_name, model_module_name))
        ModelClass = getattr(importlib.import_module(model_module_name), model_class_name)
        print("Loading model from {}".format(args.model))
        dmodel = ModelClass(filename=args.model)

    if anchor_images is not None:
        # A non-positive batch size made the old queue loop spin forever.
        if args.batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(args.batch_size))
        anchors = None
        # Encode anchor images args.batch_size at a time.
        for start in range(0, len(anchor_images), args.batch_size):
            encoded = dmodel.encode_images(anchor_images[start:start + args.batch_size])
            if anchors is None:
                anchors = encoded
            else:
                anchors = np.concatenate((anchors, encoded), axis=0)
    elif args.anchor_vectors is not None:
        anchors = get_json_vectors(args.anchor_vectors)
    else:
        anchors = None

    if args.invert_anchors:
        anchors = -1 * anchors

    if args.encoder:
        if anchors is not None:
            output_vectors(anchors)
        else:
            stream_output_vectors(dmodel, args.dataset, args.split, batch_size=args.batch_size)
        sys.exit(0)

    global_offset = None
    if args.anchor_offset is not None:
        # compute anchors as offsets from existing anchor
        offsets = get_json_vectors(args.anchor_offset)
        if args.anchor_noise:
            anchors = anchors_noise_offsets(anchors, offsets, args.rows, args.cols, args.spacing,
                cur_z_step, args.anchor_offset_x, args.anchor_offset_y,
                args.anchor_offset_x_minscale, args.anchor_offset_y_minscale,
                args.anchor_offset_x_maxscale, args.anchor_offset_y_maxscale)
        else:
            anchors = anchors_from_offsets(anchors[0], offsets, args.anchor_offset_x, args.anchor_offset_y,
                args.anchor_offset_x_minscale, args.anchor_offset_y_minscale,
                args.anchor_offset_x_maxscale, args.anchor_offset_y_maxscale)

    if args.global_offset is not None:
        offsets = get_json_vectors(args.global_offset)
        if args.global_ramp:
            offsets = cur_z_step * offsets
        global_offset = get_global_offset(offsets, args.global_indices, args.global_scale)

    z_dim = dmodel.get_zdim()
    # partway/encircle are not adapted to the chain layout; this branch also
    # covers "mine" mode with random anchors.
    if (args.partway is not None) or args.encircle or (args.mine and anchors is None):
        srows = ((args.rows // args.spacing) + 1)
        scols = ((args.cols // args.spacing) + 1)
        rand_anchors = generate_latent_grid(z_dim, rows=srows, cols=scols, fan=False, gradient=False,
            spherical=False, gaussian=False, anchors=None, anchor_images=None, mine=False, chain=False,
            spacing=args.spacing, analogy=False, rand_uniform=args.uniform)
        if args.partway is not None:
            # Blend random anchors toward the provided ones by args.partway.
            l = len(rand_anchors)
            clipped_anchors = anchors[:l]
            anchors = (1.0 - args.partway) * rand_anchors + args.partway * clipped_anchors
        elif args.encircle:
            anchors = surround_anchors(srows, scols, anchors, rand_anchors)
        else:
            anchors = rand_anchors
    z = generate_latent_grid(z_dim, args.rows, args.cols, args.fan, args.gradient, not args.linear, args.gaussian,
            anchors, anchor_images, args.mine, args.chain, args.spacing, args.analogy)
    if global_offset is not None:
        z = z + global_offset

    grid_from_latents(z, dmodel, args.rows, args.cols, anchor_images, args.tight, args.shoulders, cur_save_path, args.batch_size)
    return dmodel
예제 #5
0
def run_with_args(args,
                  dmodel,
                  cur_anchor_image,
                  cur_save_path,
                  cur_z_step,
                  cur_basename="basename",
                  range_data=None,
                  template_dict=None):
    """Build a latent grid from the options in ``args`` and save it as an image.

    Anchor images (and optionally their labels) come from the dataset, from
    a file glob, or from ``cur_anchor_image``. They are encoded with
    ``dmodel`` (loaded via ``zoo.load_model`` when None), offset, expanded
    into a latent grid, and decoded/saved by
    ``plat.sampling.grid_from_latents``.

    Args:
        args: parsed options namespace; many attributes are read.
        dmodel: an already-loaded model, or None to load one.
        cur_anchor_image: passed to ``anchors_from_image``, or None.
        cur_save_path: output filename template.
        cur_z_step: forwarded to the anchor-offset helpers and used to scale
            the global offset when ``args.global_ramp`` is set.
        cur_basename: stored as ``template_dict["BASENAME"]``.
        range_data: forwarded to ``anchors_json_offsets`` when not None.
        template_dict: filename template values; updated in place. Defaults
            to a fresh empty dict per call.

    Returns:
        The model used, so callers can reuse it on the next call.

    Note:
        Calls ``sys.exit(0)`` in passthrough and encoder modes, and when the
        anchor glob yields no images.

    Raises:
        ValueError: in passthrough mode without anchor images, or if anchor
            images must be encoded and ``args.batch_size`` < 1.
    """
    # Fresh dict per call; the former `template_dict={}` default was shared
    # across calls and is mutated below.
    if template_dict is None:
        template_dict = {}

    anchor_images = None
    anchor_labels = None
    if args.anchors:
        allowed = None
        prohibited = None
        include_targets = False
        # list(): on Python 3 map() is a one-shot iterator.
        if args.allowed:
            include_targets = True
            allowed = list(map(int, args.allowed.split(",")))
        if args.prohibited:
            include_targets = True
            prohibited = list(map(int, args.prohibited.split(",")))
        anchor_images = get_anchor_images(args.dataset,
                                          args.split,
                                          args.offset,
                                          args.stepsize,
                                          args.numanchors,
                                          allowed,
                                          prohibited,
                                          args.image_size,
                                          args.color_convert,
                                          include_targets=include_targets)
        if args.with_labels:
            anchor_labels = get_anchor_labels(args.dataset, args.split,
                                              args.offset, args.stepsize,
                                              args.numanchors)

    if args.anchor_glob is not None:
        files = plat.sampling.real_glob(args.anchor_glob)
        if args.offset > 0:
            files = files[args.offset:]
        if args.stepsize > 1:
            files = files[::args.stepsize]
        if args.numanchors is not None:
            files = files[:args.numanchors]
        anchor_images = anchors_from_filelist(files)
        print("Read {} images from {} files".format(len(anchor_images),
                                                    len(files)))
        if len(anchor_images) == 0:
            print("No images, cannot contine")
            sys.exit(0)

    if cur_anchor_image is not None:
        _, _, anchor_images = anchors_from_image(cur_anchor_image,
                                                 image_size=(args.image_size,
                                                             args.image_size))
        if args.offset > 0:
            anchor_images = anchor_images[args.offset:]
        if args.stepsize > 0:
            anchor_images = anchor_images[::args.stepsize]
        if args.numanchors is not None:
            anchor_images = anchor_images[:args.numanchors]

    # Placeholder (all-None) labels so images and labels can be batched together.
    if anchor_images is not None and anchor_labels is None:
        anchor_labels = [None] * len(anchor_images)

    if args.passthrough:
        if anchor_images is None or len(anchor_images) == 0:
            raise ValueError("passthrough requires anchor images")
        # determine final filename string
        save_path = plat.sampling.emit_filename(cur_save_path, {}, args)
        print("Preparing image file {}".format(save_path))
        img = grid2img(anchor_images, args.rows, args.cols, not args.tight)
        img.save(save_path)
        sys.exit(0)

    if dmodel is None:
        dmodel = zoo.load_model(args.model, args.model_file, args.model_type,
                                args.model_interface)

    embedded = None
    if anchor_images is not None:
        if args.batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(
                args.batch_size))
        anchors = None
        for start in range(0, len(anchor_images), args.batch_size):
            stop = start + args.batch_size
            cur_x = anchor_images[start:stop]
            cur_c = anchor_labels[start:stop]
            encoded = dmodel.encode_images(cur_x, cur_c)
            try:
                emb_l = dmodel.embed_labels(cur_c)
            except AttributeError:
                # One placeholder per image in *this* batch; the last batch
                # may be shorter than args.batch_size.
                emb_l = [None] * len(cur_x)
            if anchors is None:
                anchors = encoded
                embedded = emb_l
            else:
                anchors = np.concatenate((anchors, encoded), axis=0)
                embedded = np.concatenate((embedded, emb_l), axis=0)
    elif args.anchor_vectors is not None:
        anchors = get_json_vectors(args.anchor_vectors)
    else:
        anchors = None

    if args.invert_anchors:
        anchors = -1 * anchors

    if args.encoder:
        if anchors is not None:
            plat.sampling.output_vectors(anchors, args.save_path)
        else:
            plat.sampling.stream_output_vectors(dmodel,
                                                args.dataset,
                                                args.split,
                                                args.save_path,
                                                batch_size=args.batch_size)
        sys.exit(0)

    global_offset = None
    if args.anchor_offset is not None:
        # compute anchors as offsets from existing anchor
        offsets = get_json_vectors(args.anchor_offset)
        if args.anchor_wave:
            anchors = plat.sampling.anchors_wave_offsets(
                anchors, offsets, args.rows, args.cols, args.spacing,
                args.radial_wave, args.clip_wave, cur_z_step,
                args.anchor_offset_x, args.anchor_offset_x_minscale,
                args.anchor_offset_x_maxscale)
        elif args.anchor_noise:
            anchors = plat.sampling.anchors_noise_offsets(
                anchors, offsets, args.rows, args.cols, args.spacing,
                cur_z_step, args.anchor_offset_x, args.anchor_offset_y,
                args.anchor_offset_x_minscale, args.anchor_offset_y_minscale,
                args.anchor_offset_x_maxscale, args.anchor_offset_y_maxscale)
        elif range_data is not None:
            anchors = plat.sampling.anchors_json_offsets(
                anchors, offsets, args.rows, args.cols, args.spacing,
                cur_z_step, args.anchor_offset_x, args.anchor_offset_y,
                args.anchor_offset_x_minscale, args.anchor_offset_y_minscale,
                args.anchor_offset_x_maxscale, args.anchor_offset_y_maxscale,
                range_data)
        else:
            anchors = plat.sampling.anchors_from_offsets(
                anchors[0], offsets, args.anchor_offset_x,
                args.anchor_offset_y, args.anchor_offset_x_minscale,
                args.anchor_offset_y_minscale, args.anchor_offset_x_maxscale,
                args.anchor_offset_y_maxscale)

    if args.global_offset is not None:
        offsets = get_json_vectors(args.global_offset)
        if args.global_ramp:
            offsets = cur_z_step * offsets
        global_offset = plat.sampling.get_global_offset(
            offsets, args.global_indices, args.global_scale)

    z_dim = dmodel.get_zdim()
    # partway/encircle are not adapted to the chain layout; this branch also
    # supplies random anchors when none were given.
    if (args.partway is not None) or args.encircle or (anchors is None):
        srows = ((args.rows // args.spacing) + 1)
        scols = ((args.cols // args.spacing) + 1)
        rand_anchors = plat.sampling.generate_latent_grid(
            z_dim,
            rows=srows,
            cols=scols,
            fan=False,
            gradient=False,
            spherical=False,
            gaussian=False,
            anchors=None,
            anchor_images=None,
            mine=False,
            chain=False,
            spacing=args.spacing,
            analogy=False,
            rand_uniform=args.uniform)
        if args.partway is not None:
            # Blend random anchors toward the provided ones by args.partway.
            l = len(rand_anchors)
            clipped_anchors = anchors[:l]
            anchors = (1.0 - args.partway
                       ) * rand_anchors + args.partway * clipped_anchors
        elif args.encircle:
            anchors = surround_anchors(srows, scols, anchors, rand_anchors)
        else:
            anchors = rand_anchors
    z = plat.sampling.generate_latent_grid(z_dim, args.rows, args.cols,
                                           args.fan, args.gradient,
                                           not args.linear, args.gaussian,
                                           anchors, anchor_images, True,
                                           args.chain, args.spacing,
                                           args.analogy)
    if global_offset is not None:
        z = z + global_offset

    template_dict["BASENAME"] = cur_basename
    # Only build per-cell label embeddings when the model actually produced
    # some. Otherwise pass None so grid_from_latents uses plain sample_at
    # instead of decoding with placeholder/random embeddings.
    embedded_labels = None
    if embedded is not None and embedded[0] is not None:
        if args.clone_label is not None:
            embedded_labels = np.tile(embedded[args.clone_label], [len(z), 1])
        else:
            embedded_labels = plat.sampling.generate_latent_grid(
                z_dim, args.rows, args.cols, args.fan, args.gradient,
                not args.linear, args.gaussian, embedded, anchor_images, True,
                args.chain, args.spacing, args.analogy)

    plat.sampling.grid_from_latents(z,
                                    dmodel,
                                    args.rows,
                                    args.cols,
                                    anchor_images,
                                    args.tight,
                                    args.shoulders,
                                    cur_save_path,
                                    args,
                                    args.batch_size,
                                    template_dict=template_dict,
                                    emb_l=embedded_labels)
    return dmodel
예제 #6
0
파일: sample.py 프로젝트: dribnet/plat
def run_with_args(args, dmodel, cur_anchor_image, cur_save_path, cur_z_step, cur_basename="basename", range_data=None, template_dict=None):
    """Build a latent grid from the options in ``args`` and save it as an image.

    Anchor images (and optionally their labels) come from the dataset, from
    a file glob, or from ``cur_anchor_image``. They are encoded with
    ``dmodel`` (loaded via ``zoo.load_model`` when None), offset, expanded
    into a latent grid, and decoded/saved by
    ``plat.sampling.grid_from_latents``.

    Args:
        args: parsed options namespace; many attributes are read.
        dmodel: an already-loaded model, or None to load one.
        cur_anchor_image: single file passed to ``anchors_from_filelist``,
            or None.
        cur_save_path: output filename template.
        cur_z_step: forwarded to the anchor-offset helpers and used to scale
            the global offset when ``args.global_ramp`` is set.
        cur_basename: stored as ``template_dict["BASENAME"]``.
        range_data: forwarded to ``anchors_json_offsets`` when not None.
        template_dict: filename template values; updated in place. Defaults
            to a fresh empty dict per call.

    Returns:
        The model used, so callers can reuse it on the next call.

    Note:
        Calls ``sys.exit(0)`` in passthrough and encoder modes, and when the
        anchor glob yields no images. Reseeds ``numpy`` and ``random``
        (from ``args.seed``, or from fresh entropy when it is None).

    Raises:
        ValueError: in passthrough mode without anchor images, or if anchor
            images must be encoded and ``args.batch_size`` < 1.
    """
    # Fresh dict per call; the former `template_dict={}` default was shared
    # across calls and is mutated below.
    if template_dict is None:
        template_dict = {}

    anchor_images = None
    anchor_labels = None
    if args.anchors:
        allowed = None
        prohibited = None
        include_targets = False
        # list(): on Python 3 map() is a one-shot iterator.
        if args.allowed:
            include_targets = True
            allowed = list(map(int, args.allowed.split(",")))
        if args.prohibited:
            include_targets = True
            prohibited = list(map(int, args.prohibited.split(",")))
        anchor_images = get_anchor_images(args.dataset, args.split, args.offset, args.stepsize, args.numanchors,
                                          allowed, prohibited, args.image_size, args.color_convert,
                                          include_targets=include_targets)
        if args.with_labels:
            anchor_labels = get_anchor_labels(args.dataset, args.split, args.offset, args.stepsize, args.numanchors)

    if args.anchor_glob is not None:
        files = plat.sampling.real_glob(args.anchor_glob)
        if args.offset > 0:
            files = files[args.offset:]
        if args.stepsize > 1:
            files = files[::args.stepsize]
        if args.numanchors is not None:
            files = files[:args.numanchors]
        anchor_images = anchors_from_filelist(files, args.channels)
        print("Read {} images from {} files".format(len(anchor_images), len(files)))
        print("First 5 files: ", files[:5])
        if len(anchor_images) == 0:
            print("No images, cannot contine")
            sys.exit(0)

    if cur_anchor_image is not None:
        anchor_images = anchors_from_filelist([cur_anchor_image], channels=args.channels)
        if args.offset > 0:
            anchor_images = anchor_images[args.offset:]
        if args.stepsize > 0:
            anchor_images = anchor_images[::args.stepsize]
        if args.numanchors is not None:
            anchor_images = anchor_images[:args.numanchors]

    # Placeholder (all-None) labels so images and labels can be batched together.
    if anchor_images is not None and anchor_labels is None:
        anchor_labels = [None] * len(anchor_images)

    if args.passthrough:
        if anchor_images is None or len(anchor_images) == 0:
            raise ValueError("passthrough requires anchor images")
        # determine final filename string
        save_path = plat.sampling.emit_filename(cur_save_path, {}, args)
        print("Preparing image file {}".format(save_path))
        img = grid2img(anchor_images, args.rows, args.cols, not args.tight)
        img.save(save_path)
        sys.exit(0)

    if dmodel is None:
        dmodel = zoo.load_model(args.model, args.model_file, args.model_type, args.model_interface)

    # Seeding happens after model loading, so it governs only the sampling below.
    if args.seed is not None:
        print("Setting random seed to ", args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
    else:
        np.random.seed(None)
        random.seed(None)

    embedded = None
    if anchor_images is not None:
        if args.batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(args.batch_size))
        anchors = None
        for start in range(0, len(anchor_images), args.batch_size):
            stop = start + args.batch_size
            cur_x = anchor_images[start:stop]
            cur_c = anchor_labels[start:stop]
            # TODO: remove vestiges of conditional encode/decode
            encoded = dmodel.encode_images(cur_x)
            try:
                emb_l = dmodel.embed_labels(cur_c)
            except AttributeError:
                # One placeholder per image in *this* batch; the last batch
                # may be shorter than args.batch_size.
                emb_l = [None] * len(cur_x)
            if anchors is None:
                anchors = encoded
                embedded = emb_l
            else:
                anchors = np.concatenate((anchors, encoded), axis=0)
                embedded = np.concatenate((embedded, emb_l), axis=0)
    elif args.anchor_vectors is not None:
        anchors = get_json_vectors(args.anchor_vectors)
        # Flatten any leading dimensions into a 2-D (count, vector_size) array.
        vsize = anchors.shape[-1]
        anchors = anchors.reshape([-1, vsize])
        print("Read vectors: ", anchors.shape)
    else:
        anchors = None

    if args.invert_anchors:
        anchors = -1 * anchors

    if args.encoder:
        if anchors is not None:
            plat.sampling.output_vectors(anchors, args.save_path)
        else:
            plat.sampling.stream_output_vectors(dmodel, args.dataset, args.split, args.save_path, batch_size=args.batch_size)
        sys.exit(0)

    global_offset = None
    if args.anchor_offset is not None:
        # compute anchors as offsets from existing anchor
        offsets = get_json_vectors_list(args.anchor_offset)
        if args.anchor_wave:
            anchors = plat.sampling.anchors_wave_offsets(anchors, offsets, args.rows, args.cols, args.spacing,
                args.radial_wave, args.clip_wave, cur_z_step, args.anchor_offset_x,
                args.anchor_offset_x_minscale, args.anchor_offset_x_maxscale)
        elif args.anchor_noise:
            anchors = plat.sampling.anchors_noise_offsets(anchors, offsets, args.rows, args.cols, args.spacing,
                cur_z_step, args.anchor_offset_x, args.anchor_offset_y,
                args.anchor_offset_x_minscale, args.anchor_offset_y_minscale,
                args.anchor_offset_x_maxscale, args.anchor_offset_y_maxscale)
        elif range_data is not None:
            anchors = plat.sampling.anchors_json_offsets(anchors, offsets, args.rows, args.cols, args.spacing,
                cur_z_step, args.anchor_offset_x, args.anchor_offset_y,
                args.anchor_offset_x_minscale, args.anchor_offset_y_minscale,
                args.anchor_offset_x_maxscale, args.anchor_offset_y_maxscale,
                range_data)
        else:
            anchors = plat.sampling.anchors_from_offsets(anchors[0], offsets, args.anchor_offset_x, args.anchor_offset_y,
                args.anchor_offset_x_minscale, args.anchor_offset_y_minscale,
                args.anchor_offset_x_maxscale, args.anchor_offset_y_maxscale)

    if args.global_offset is not None:
        offsets = get_json_vectors(args.global_offset)
        if args.global_ramp:
            offsets = cur_z_step * offsets
        global_offset = plat.sampling.get_global_offset(offsets, args.global_indices, args.global_scale)

    z_dim = dmodel.get_zdim()
    # partway/encircle are not adapted to the chain layout; this branch also
    # supplies random anchors when none were given.
    if (args.partway is not None) or args.encircle or (anchors is None):
        srows = ((args.rows // args.spacing) + 1)
        scols = ((args.cols // args.spacing) + 1)
        rand_anchors = plat.sampling.generate_latent_grid(z_dim, rows=srows, cols=scols, fan=False, gradient=False,
            spherical=False, gaussian=False, anchors=None, anchor_images=None, mine=False, chain=False,
            spacing=args.spacing, analogy=False, rand_uniform=args.uniform)
        if args.partway is not None:
            # Blend random anchors toward the provided ones by args.partway.
            l = len(rand_anchors)
            clipped_anchors = anchors[:l]
            anchors = (1.0 - args.partway) * rand_anchors + args.partway * clipped_anchors
        elif args.encircle:
            anchors = surround_anchors(srows, scols, anchors, rand_anchors)
        else:
            anchors = rand_anchors
    z = plat.sampling.generate_latent_grid(z_dim, args.rows, args.cols, args.fan, args.gradient, not args.linear, args.gaussian,
            anchors, anchor_images, True, args.chain, args.spacing, args.analogy)
    if args.write_anchors:
        plat.sampling.output_vectors(anchors, "anchors.json")

    if global_offset is not None:
        z = z + global_offset

    template_dict["BASENAME"] = cur_basename
    # Only build per-cell label embeddings when the model actually produced
    # some; otherwise grid_from_latents falls back to plain sample_at.
    embedded_labels = None
    if embedded is not None and embedded[0] is not None:
        if args.clone_label is not None:
            embedded_labels = np.tile(embedded[args.clone_label], [len(z), 1])
        else:
            embedded_labels = plat.sampling.generate_latent_grid(z_dim, args.rows, args.cols, args.fan, args.gradient, not args.linear, args.gaussian,
                    embedded, anchor_images, True, args.chain, args.spacing, args.analogy)

    plat.sampling.grid_from_latents(z, dmodel, args.rows, args.cols, anchor_images, args.tight, args.shoulders, cur_save_path, args, args.batch_size, template_dict=template_dict, emb_l=embedded_labels)
    return dmodel
예제 #7
0
def run_with_args(args, dmodel, cur_anchor_image, cur_save_path, cur_z_step):
    """Build a latent grid from the options in ``args`` and save it as an image.

    The steps run in this order:

    1. Optionally seed the RNGs.
    2. Collect anchor images from the dataset and/or ``cur_anchor_image``.
    3. In passthrough mode, save those anchors directly.
    4. Load the model when ``dmodel`` is None.
    5. Encode the anchors to latents, or read them from JSON.
    6. Apply anchor and global offsets.
    7. Generate the latent grid and decode and save it via ``grid_from_latents``.

    Args:
        args: parsed options namespace; many attributes are read.
        dmodel: an already-loaded model, or None to instantiate
            ``args.model_class`` with ``filename=args.model``.
        cur_anchor_image: passed to ``anchors_from_image``, or None.
        cur_save_path: output filename for the grid image.
        cur_z_step: forwarded to ``anchors_noise_offsets`` and used to scale
            the global offset when ``args.global_ramp`` is set.

    Returns:
        The model used, so callers can reuse it on the next call.

    Note:
        Calls ``sys.exit(0)`` in passthrough and encoder modes.

    Raises:
        ValueError: if anchor images must be encoded and
            ``args.batch_size`` < 1.
    """
    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)

    anchor_images = None
    if args.anchors:
        _, get_anchor_images = lazy_init_fuel_dependencies()
        allowed = None
        prohibited = None
        include_targets = False
        # list(): on Python 3 map() is a one-shot iterator, which silently
        # breaks any repeated iteration or membership test downstream.
        if args.allowed:
            include_targets = True
            allowed = list(map(int, args.allowed.split(",")))
        if args.prohibited:
            include_targets = True
            prohibited = list(map(int, args.prohibited.split(",")))
        anchor_images = get_anchor_images(args.dataset,
                                          args.split,
                                          args.offset,
                                          args.stepsize,
                                          args.numanchors,
                                          allowed,
                                          prohibited,
                                          args.image_size,
                                          args.color_convert,
                                          include_targets=include_targets)

    if cur_anchor_image is not None:
        _, _, anchor_images = anchors_from_image(cur_anchor_image,
                                                 image_size=(args.image_size,
                                                             args.image_size))
        if args.offset > 0:
            anchor_images = anchor_images[args.offset:]
        # untested
        if args.numanchors is not None:
            anchor_images = anchor_images[:args.numanchors]

    if args.passthrough:
        print('Preparing image grid...')
        img = grid2img(anchor_images, args.rows, args.cols, not args.tight)
        img.save(cur_save_path)
        sys.exit(0)

    if dmodel is None:
        # args.model_class is a dotted path: "package.module.ClassName".
        model_class_parts = args.model_class.split(".")
        model_class_name = model_class_parts[-1]
        model_module_name = ".".join(model_class_parts[:-1])
        print("Loading {} interface from {}".format(model_class_name,
                                                    model_module_name))
        ModelClass = getattr(importlib.import_module(model_module_name),
                             model_class_name)
        print("Loading model from {}".format(args.model))
        dmodel = ModelClass(filename=args.model)

    if anchor_images is not None:
        # A non-positive batch size made the old queue loop spin forever.
        if args.batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(
                args.batch_size))
        anchors = None
        # Encode anchor images args.batch_size at a time.
        for start in range(0, len(anchor_images), args.batch_size):
            encoded = dmodel.encode_images(
                anchor_images[start:start + args.batch_size])
            if anchors is None:
                anchors = encoded
            else:
                anchors = np.concatenate((anchors, encoded), axis=0)
    elif args.anchor_vectors is not None:
        anchors = get_json_vectors(args.anchor_vectors)
    else:
        anchors = None

    if args.invert_anchors:
        anchors = -1 * anchors

    if args.encoder:
        if anchors is not None:
            output_vectors(anchors)
        else:
            stream_output_vectors(dmodel,
                                  args.dataset,
                                  args.split,
                                  batch_size=args.batch_size)
        sys.exit(0)

    global_offset = None
    if args.anchor_offset is not None:
        # compute anchors as offsets from existing anchor
        offsets = get_json_vectors(args.anchor_offset)
        if args.anchor_noise:
            anchors = anchors_noise_offsets(
                anchors, offsets, args.rows, args.cols, args.spacing,
                cur_z_step, args.anchor_offset_x, args.anchor_offset_y,
                args.anchor_offset_x_minscale, args.anchor_offset_y_minscale,
                args.anchor_offset_x_maxscale, args.anchor_offset_y_maxscale)
        else:
            anchors = anchors_from_offsets(
                anchors[0], offsets, args.anchor_offset_x,
                args.anchor_offset_y, args.anchor_offset_x_minscale,
                args.anchor_offset_y_minscale, args.anchor_offset_x_maxscale,
                args.anchor_offset_y_maxscale)

    if args.global_offset is not None:
        offsets = get_json_vectors(args.global_offset)
        if args.global_ramp:
            offsets = cur_z_step * offsets
        global_offset = get_global_offset(offsets, args.global_indices,
                                          args.global_scale)

    z_dim = dmodel.get_zdim()
    # partway/encircle are not adapted to the chain layout; this branch also
    # covers "mine" mode with random anchors.
    if (args.partway is not None) or args.encircle or (args.mine
                                                       and anchors is None):
        srows = ((args.rows // args.spacing) + 1)
        scols = ((args.cols // args.spacing) + 1)
        rand_anchors = generate_latent_grid(z_dim,
                                            rows=srows,
                                            cols=scols,
                                            fan=False,
                                            gradient=False,
                                            spherical=False,
                                            gaussian=False,
                                            anchors=None,
                                            anchor_images=None,
                                            mine=False,
                                            chain=False,
                                            spacing=args.spacing,
                                            analogy=False,
                                            rand_uniform=args.uniform)
        if args.partway is not None:
            # Blend random anchors toward the provided ones by args.partway.
            l = len(rand_anchors)
            clipped_anchors = anchors[:l]
            anchors = (1.0 - args.partway
                       ) * rand_anchors + args.partway * clipped_anchors
        elif args.encircle:
            anchors = surround_anchors(srows, scols, anchors, rand_anchors)
        else:
            anchors = rand_anchors
    z = generate_latent_grid(z_dim, args.rows, args.cols, args.fan,
                             args.gradient, not args.linear, args.gaussian,
                             anchors, anchor_images, args.mine, args.chain,
                             args.spacing, args.analogy)
    if global_offset is not None:
        z = z + global_offset

    grid_from_latents(z, dmodel, args.rows, args.cols, anchor_images,
                      args.tight, args.shoulders, cur_save_path,
                      args.batch_size)
    return dmodel