Example #1
    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            # customs
            embed_dim=256,
            encoder_type="impala",
            **kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.action_dim = action_space.n
        self.discrete = True
        self.action_outs = q_outs = self.action_dim
        self.action_ins = None  # No action inputs for the discrete case.
        self.embed_dim = embed_dim

        h, w, c = obs_space.shape
        shape = (c, h, w)
        # obs embedding
        self.encoder = make_encoder(encoder_type,
                                    shape,
                                    out_features=embed_dim)

        self.logits_fc = nn.Linear(in_features=embed_dim,
                                   out_features=num_outputs)
        self.value_fc = nn.Linear(in_features=embed_dim, out_features=1)
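
A minimal usage sketch for a model like this one, assuming a hypothetical class name `BaselinePPOTorchModel` (the snippet does not show the class statement). In RLlib, custom TorchModelV2 subclasses are registered through ModelCatalog, and the extra keyword arguments are forwarded via the model config (the key is "custom_model_config" in recent RLlib versions, "custom_options" in older ones):

from ray.rllib.models import ModelCatalog

# Hypothetical registration; `BaselinePPOTorchModel` stands in for the
# (unshown) class whose __init__ appears above.
ModelCatalog.register_custom_model("encoder_model", BaselinePPOTorchModel)

config = {
    "model": {
        "custom_model": "encoder_model",
        # forwarded into __init__ as embed_dim=..., encoder_type=...
        "custom_model_config": {"embed_dim": 256, "encoder_type": "impala"},
    },
}
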
Example #2
    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            # customs
            embed_dim=256,
            encoder_type="impala",
            augmentation=False,
            aug_num=2,
            max_shift=4,
            **kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.action_dim = action_space.n
        self.discrete = True
        self.action_outs = q_outs = self.action_dim
        self.action_ins = None  # No action inputs for the discrete case.
        self.embed_dim = embed_dim

        h, w, c = obs_space.shape
        shape = (c, h, w)
        # obs embedding
        self.encoder = make_encoder(encoder_type,
                                    shape,
                                    out_features=embed_dim)

        self.logits_fc = nn.Linear(in_features=embed_dim,
                                   out_features=num_outputs)
        self.value_fc = nn.Linear(in_features=embed_dim, out_features=1)

        # customs
        self.augmentation = augmentation
        self.aug_num = aug_num
        if augmentation:
            obs_shape = obs_space.shape[-2]
            self.trans = nn.Sequential(nn.ReplicationPad2d(max_shift),
                                       RandomCrop((obs_shape, obs_shape)))
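
The `self.trans` pipeline above implements the pad-then-random-crop ("random shift") image augmentation common in pixel-based RL. A self-contained sketch of the same transform, assuming square 84x84 observations and that `RandomCrop` is kornia's:

import torch
import torch.nn as nn
import kornia

max_shift, obs_size = 4, 84
trans = nn.Sequential(
    nn.ReplicationPad2d(max_shift),                        # pad each side: 84 -> 92
    kornia.augmentation.RandomCrop((obs_size, obs_size)))  # back to 84 at a random offset
x = torch.rand(8, 3, obs_size, obs_size)  # NCHW batch of observations
assert trans(x).shape == x.shape          # shape-preserving random shift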

Example #3
feature_dir = 'cui_emd_features/'

for task_id in range(50):
    if os.path.exists(feature_dir + str(task_id) + '_weight.npy'): continue
    cfg = dict_to_cfg({
        'name': 'cub_inat2018_no_aug',
        'task_id': task_id,
        'root': '/data'
    })
    train_dataset, _ = get_dataset('/data/', cfg)
    if hasattr(train_dataset, 'task_name'):
        print(f"======= Embedding for task: {train_dataset.task_name} =======")
    probe_network = make_encoder(
        get_model('resnet18',
                  pretraining='imagenet',
                  num_classes=train_dataset.num_classes)).cuda()
    data_loader = torch.utils.data.DataLoader(train_dataset,
                                              shuffle=False,
                                              batch_size=128,
                                              num_workers=1,
                                              drop_last=False)

    targets = []
    features = []
    probe_network.linear = torch.nn.Identity()  # drop the classification head
    for images, target in data_loader:
        targets.extend(target.detach().cpu().numpy())
        features.extend(probe_network(images.cuda()).detach().cpu().numpy())
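
The loop above only accumulates features and targets in memory. A hypothetical continuation (not shown in the original; the `_feature.npy` and `_target.npy` names are assumptions) would persist them per task:

import numpy as np

# Hypothetical save step; the existence check at the top of the loop
# looks for '<task_id>_weight.npy', which is presumably written by a
# later, elided stage.
np.save(feature_dir + f'{task_id}_feature.npy', np.stack(features))
np.save(feature_dir + f'{task_id}_target.npy', np.asarray(targets))
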
Example #4
    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            actor_hidden_activation="relu",
            actor_hiddens=(256, 256),
            critic_hidden_activation="relu",
            critic_hiddens=(256, 256),
            twin_q=False,
            initial_alpha=1.0,
            target_entropy=None,
            # customs
            embed_dim=50,
            augmentation=False,
            aug_num=2,
            max_shift=4,
            encoder_feature_dim=50,
            num_layers=4,
            num_filters=32,
            decoder_type='pixel',
            encoder_type='pixel',
            **kwargs):
        """Initialize variables of this model.

        Extra model kwargs:
            actor_hidden_activation (str): activation for actor network
            actor_hiddens (list): hidden layers sizes for actor network
            critic_hidden_activation (str): activation for critic network
            critic_hiddens (list): hidden layers sizes for critic network
            twin_q (bool): build twin Q networks.
            initial_alpha (float): The initial value for the to-be-optimized
                alpha parameter (default: 1.0).
            target_entropy (Optional[float]): An optional fixed value for the
                SAC alpha loss term. None or "auto" for automatic calculation
                of this value according to [1] (cont. actions) or [2]
                (discrete actions).

        Note that the core layers for forward() are not defined here; this
        class only defines the layers for the output heads. The layers for
        forward() should be defined in subclasses of SACModel.
        """
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.action_dim = action_space.n
        self.discrete = True
        self.action_outs = q_outs = self.action_dim
        self.action_ins = None  # No action inputs for the discrete case.
        self.embed_dim = embed_dim

        h, w, c = obs_space.shape
        shape = (c, h, w)
        obs_shape = shape

        # obs embedding
        # conv_seqs = []
        # for out_channels in [16, 32, 32]:
        #     conv_seq = ConvSequence(shape, out_channels)
        #     shape = conv_seq.get_output_shape()
        #     conv_seqs.append(conv_seq)
        # self.conv_seqs = nn.ModuleList(conv_seqs)
        # self.hidden_fc = nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=embed_dim)

        # Build the policy network
        # TODO: fix func input
        self.actor_encoder = make_encoder(encoder_type, obs_shape,
                                          encoder_feature_dim, num_layers,
                                          num_filters)

        self.critic_encoder = make_encoder(encoder_type, obs_shape,
                                           encoder_feature_dim, num_layers,
                                           num_filters)

        # tie encoders between actor and critic
        self.actor_encoder.copy_conv_weights_from(self.critic_encoder)

        self.action_model = nn.Sequential()
        # img -> embedding
        ins = embed_dim
        act = get_activation_fn(actor_hidden_activation, framework="torch")
        init = nn.init.xavier_uniform_
        outs = self.actor_encoder.feature_dim
        # embedding to autoencoder embed
        self.action_model.add_module(
            "action_{}".format('e'),
            SlimFC(ins, outs, initializer=init, activation_fn=act))
        ins = outs
        # add trunk model
        for i, n in enumerate(actor_hiddens):
            self.action_model.add_module(
                "action_{}".format(i),
                SlimFC(ins, n, initializer=init, activation_fn=act))
            ins = n
        self.action_model.add_module(
            "action_out",
            SlimFC(ins, self.action_outs, initializer=init,
                   activation_fn=None))

        # Build the Q-net(s), including target Q-net(s).
        def build_q_net(name_):

            act = get_activation_fn(critic_hidden_activation,
                                    framework="torch")
            init = nn.init.xavier_uniform_
            # For discrete actions, only obs.
            q_net = nn.Sequential()
            ins = embed_dim
            # embed to encoder embed
            outs = self.critic_encoder.feature_dim
            q_net.add_module(
                "{}_hidden_{}".format(name_, "e"),
                SlimFC(ins, outs, initializer=init, activation_fn=act))
            ins = outs

            for i, n in enumerate(critic_hiddens):
                q_net.add_module(
                    "{}_hidden_{}".format(name_, i),
                    SlimFC(ins, n, initializer=init, activation_fn=act))
                ins = n

            q_net.add_module(
                "{}_out".format(name_),
                SlimFC(ins, q_outs, initializer=init, activation_fn=None))
            return q_net

        self.q_net = build_q_net("q")
        if twin_q:
            self.twin_q_net = build_q_net("twin_q")
        else:
            self.twin_q_net = None

        # temperature tensor
        self.log_alpha = torch.tensor(data=[np.log(initial_alpha)],
                                      dtype=torch.float32,
                                      requires_grad=True)

        # Auto-calculate the target entropy.
        if target_entropy is None or target_entropy == "auto":
            # See hyperparams in [2] (README.md).
            target_entropy = 0.98 * np.array(-np.log(1.0 / action_space.n),
                                             dtype=np.float32)

        self.target_entropy = torch.tensor(data=[target_entropy],
                                           dtype=torch.float32,
                                           requires_grad=False)

        # device = 0
        # make decoder
        self.decoder = None
        if decoder_type != 'identity':
            # create decoder
            self.decoder = make_decoder(decoder_type, obs_shape,
                                        encoder_feature_dim, num_layers,
                                        num_filters)
            self.decoder.apply(weight_init)

        # NOTE: custom fields
        self.augmentation = augmentation
        self.aug_num = aug_num
        # NOTE: augmentation
        if augmentation:
            obs_shape = obs_space.shape[-2]
            self.trans = nn.Sequential(
                nn.ReplicationPad2d(max_shift),
                kornia.augmentation.RandomCrop((obs_shape, obs_shape)))
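
`copy_conv_weights_from` is not defined in this snippet. A minimal sketch of what it typically does (following the SAC+AE reference implementation, and assuming each encoder keeps its conv layers in `self.convs`): conv parameters are shared by reference, so the actor encoder reuses, rather than copies, the critic encoder's weights.

def copy_conv_weights_from(self, source):
    # Share (tie) conv parameters by reference: after this, gradient
    # updates through either encoder touch the same tensors.
    for src_layer, trg_layer in zip(source.convs, self.convs):
        trg_layer.weight = src_layer.weight
        trg_layer.bias = src_layer.bias
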
Example #5
    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            *,
            dueling=False,
            q_hiddens=(256, ),
            dueling_activation="relu",
            use_noisy=False,
            sigma0=0.5,
            # TODO(sven): Move `add_layer_norm` into ModelCatalog as
            #  generic option, then error if we use ParameterNoise as
            #  Exploration type and do not have any LayerNorm layers in
            #  the net.
            add_layer_norm=False,
            num_atoms=1,
            v_min=-10.0,
            v_max=10.0,
            # customs
            embed_dim=50,
            encoder_type="pixel",
            num_layers=4,
            num_filters=32,
            cropped_image_size=54,
            **kwargs):
        """Initialize variables of this model.
        Extra model kwargs:
            dueling (bool): Whether to build the advantage(A)/value(V) heads
                for DDQN. If True, Q-values are calculated as:
                Q = (A - mean[A]) + V. If False, raw NN output is interpreted
                as Q-values.
            q_hiddens (List[int]): List of layer-sizes after(!) the
                Advantages(A)/Value(V)-split. Hence, each of the A- and V-
                branches will have this structure of Dense layers. To define
                the NN before this A/V-split, use - as always -
                config["model"]["fcnet_hiddens"].
            dueling_activation (str): The activation to use for all dueling
                layers (A- and V-branch). One of "relu", "tanh", "linear".
            use_noisy (bool): use noisy nets
            sigma0 (float): initial value of noisy nets
            add_layer_norm (bool): Enable layer norm (for param noise).
        """
        nn.Module.__init__(self)
        super(CurlRainbowTorchModel,
              self).__init__(obs_space, action_space, num_outputs,
                             model_config, name)

        # NOTE: customs
        self.embed_dim = embed_dim
        h, w, c = obs_space.shape
        shape = (c, h, w)
        # obs embedding
        self.encoder = make_encoder(encoder_type,
                                    shape,
                                    out_features=embed_dim)

        # NOTE: value output branches
        self.dueling = dueling
        # ins = num_outputs
        ins = embed_dim

        # Dueling case: Build the shared (advantages and value) fc-network.
        advantage_module = nn.Sequential()
        value_module = None
        # if to use noisy net
        layer_cls = NoisyLinear if use_noisy else nn.Linear

        if self.dueling:
            value_module = nn.Sequential()
            for i, n in enumerate(q_hiddens):

                # MLP layers
                advantage_module.add_module("dueling_A_{}".format(i),
                                            layer_cls(ins, n))
                value_module.add_module("dueling_V_{}".format(i),
                                        layer_cls(ins, n))

                # Add activations if necessary.
                if dueling_activation == "relu":
                    advantage_module.add_module("dueling_A_act_{}".format(i),
                                                nn.ReLU())
                    value_module.add_module("dueling_V_act_{}".format(i),
                                            nn.ReLU())
                elif dueling_activation == "tanh":
                    advantage_module.add_module("dueling_A_act_{}".format(i),
                                                nn.Tanh())
                    value_module.add_module("dueling_V_act_{}".format(i),
                                            nn.Tanh())

                # Add LayerNorm after each Dense.
                if add_layer_norm:
                    advantage_module.add_module("LayerNorm_A_{}".format(i),
                                                nn.LayerNorm(n))
                    value_module.add_module("LayerNorm_V_{}".format(i),
                                            nn.LayerNorm(n))
                ins = n

            # Actual Advantages layer (nodes=num-actions) and
            # value layer (nodes=1).
            advantage_module.add_module(
                "A", layer_cls(ins, action_space.n * num_atoms))
            value_module.add_module("V", layer_cls(ins, num_atoms))

        # Non-dueling:
        # Q-value layer (use main module's outputs as Q-values).
        else:
            # pass
            # NOTE: manually adding q value (no dueling) branch following embedding
            for i, n in enumerate(q_hiddens):
                advantage_module.add_module("Q_{}".format(i),
                                            layer_cls(ins, n))
                if dueling_activation == "relu":
                    advantage_module.add_module("Q_act_{}".format(i),
                                                nn.ReLU())
                elif dueling_activation == "tanh":
                    advantage_module.add_module("Q_act_{}".format(i),
                                                nn.Tanh())
                # Add LayerNorm after each Dense.
                if add_layer_norm:
                    advantage_module.add_module("LayerNorm_Q_{}".format(i),
                                                nn.LayerNorm(n))
                ins = n

            # Actual Q value layer (nodes=num-actions) and
            # value layer (nodes=1).
            advantage_module.add_module(
                "Q", layer_cls(ins, action_space.n * num_atoms))

        self.advantage_module = advantage_module
        self.value_module = value_module
        # distributional dqn settings
        self.num_atoms = num_atoms
        if num_atoms > 1:
            z = torch.arange(num_atoms).float()
            z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
        else:
            # num_atoms == 1 is plain (non-distributional) DQN; avoid the
            # division by zero in the support spacing above.
            z = torch.zeros(1)
        self.z = z  # return distribution support

        # NOTE: input cropping
        self.cropped_image_size = cropped_image_size
        self.center_crop = CenterCrop(cropped_image_size)
        self.random_crop = RandomCrop((cropped_image_size, cropped_image_size))
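
For reference, the support `z` built above is a uniform grid over [v_min, v_max]. A worked sketch with the common C51 configuration (these particular values are assumptions, not defaults of this model):

import torch

num_atoms, v_min, v_max = 51, -10.0, 10.0
z = v_min + torch.arange(num_atoms).float() * (v_max - v_min) / (num_atoms - 1)
# z == [-10.0, -9.6, ..., 9.6, 10.0]: 51 atoms spaced (v_max - v_min) / 50 = 0.4 apart
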
Example #6
    def __init__(self,
                obs_space,
                action_space,
                num_outputs,
                model_config,
                name,
                actor_hidden_activation="relu",
                actor_hiddens=(256, 256),
                critic_hidden_activation="relu",
                critic_hiddens=(256, 256),
                twin_q=False,
                initial_alpha=1.0,
                target_entropy=None,
                # customs
                embed_dim=256,
                encoder_type="impala",
                **kwargs):
        """Initialize variables of this model.

        Extra model kwargs:
            actor_hidden_activation (str): activation for actor network
            actor_hiddens (list): hidden layers sizes for actor network
            critic_hidden_activation (str): activation for critic network
            critic_hiddens (list): hidden layers sizes for critic network
            twin_q (bool): build twin Q networks.
            initial_alpha (float): The initial value for the to-be-optimized
                alpha parameter (default: 1.0).
            target_entropy (Optional[float]): An optional fixed value for the
                SAC alpha loss term. None or "auto" for automatic calculation
                of this value according to [1] (cont. actions) or [2]
                (discrete actions).

        Note that the core layers for forward() are not defined here; this
        class only defines the layers for the output heads. The layers for
        forward() should be defined in subclasses of SACModel.
        """
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.action_dim = action_space.n
        self.discrete = True
        self.action_outs = q_outs = self.action_dim
        self.action_ins = None  # No action inputs for the discrete case.
        self.embed_dim = embed_dim
    
        h, w, c = obs_space.shape
        shape = (c, h, w)
        # obs embedding 
        self.encoder = make_encoder(encoder_type, shape, out_features=embed_dim)
  
        # Build the policy network.
        self.action_model = nn.Sequential()
        ins = embed_dim
        act = get_activation_fn(
            actor_hidden_activation, framework="torch")
        init = nn.init.xavier_uniform_

        for i, n in enumerate(actor_hiddens):
            self.action_model.add_module(
                "action_{}".format(i), 
                SlimFC(ins, n, initializer=init, activation_fn=act)
            )
            ins = n
        self.action_model.add_module(
            "action_out",
            SlimFC(ins, self.action_outs, initializer=init, activation_fn=None)
        )

        # Build the Q-net(s), including target Q-net(s).
        def build_q_net(name_):
            act = get_activation_fn(
                critic_hidden_activation, framework="torch")
            init = nn.init.xavier_uniform_
            # For discrete actions, only obs.
            q_net = nn.Sequential()
            ins = embed_dim
            for i, n in enumerate(critic_hiddens):
                q_net.add_module(
                    "{}_hidden_{}".format(name_, i),
                    SlimFC(ins, n, initializer=init, activation_fn=act)
                )
                ins = n

            q_net.add_module(
                "{}_out".format(name_),
                SlimFC(ins, q_outs, initializer=init, activation_fn=None)
            )
            return q_net

        self.q_net = build_q_net("q")
        if twin_q:
            self.twin_q_net = build_q_net("twin_q")
        else:
            self.twin_q_net = None

        # temperature tensor 
        self.log_alpha = torch.tensor(
            data=[np.log(initial_alpha)],
            dtype=torch.float32,
            requires_grad=True)

        # Auto-calculate the target entropy.
        if target_entropy is None or target_entropy == "auto":
            # See hyperparams in [2] (README.md).
            target_entropy = 0.98 * np.array(
                -np.log(1.0 / action_space.n), dtype=np.float32)
            
        self.target_entropy = torch.tensor(
            data=[target_entropy], dtype=torch.float32, requires_grad=False)
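
The auto-calculated target entropy above is 0.98 times the entropy of a uniform policy over the discrete actions. A worked sketch for an assumed Atari-sized action space of 18 actions:

import numpy as np

n = 18  # assumed number of discrete actions
target_entropy = 0.98 * -np.log(1.0 / n)  # = 0.98 * ln(18), about 2.83 nats
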
Example #7
    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            *,
            dueling=False,
            q_hiddens=(256, ),
            dueling_activation="relu",
            use_noisy=False,
            sigma0=0.5,
            # TODO(sven): Move `add_layer_norm` into ModelCatalog as
            #  generic option, then error if we use ParameterNoise as
            #  Exploration type and do not have any LayerNorm layers in
            #  the net.
            add_layer_norm=False,
            #  customs
            embed_dim=256,
            encoder_type="impala",
            **kwargs):
        """Initialize variables of this model.
        Extra model kwargs:
            dueling (bool): Whether to build the advantage(A)/value(V) heads
                for DDQN. If True, Q-values are calculated as:
                Q = (A - mean[A]) + V. If False, raw NN output is interpreted
                as Q-values.
            q_hiddens (List[int]): List of layer-sizes after(!) the
                Advantages(A)/Value(V)-split. Hence, each of the A- and V-
                branches will have this structure of Dense layers. To define
                the NN before this A/V-split, use - as always -
                config["model"]["fcnet_hiddens"].
            dueling_activation (str): The activation to use for all dueling
                layers (A- and V-branch). One of "relu", "tanh", "linear".
            use_noisy (bool): use noisy nets
            sigma0 (float): initial value of noisy nets
            add_layer_norm (bool): Enable layer norm (for param noise).
        """
        nn.Module.__init__(self)
        super(BaselineDQNTorchModel,
              self).__init__(obs_space, action_space, num_outputs,
                             model_config, name)

        # NOTE: customs
        self.embed_dim = embed_dim
        h, w, c = obs_space.shape
        shape = (c, h, w)
        # obs embedding
        self.encoder = make_encoder(encoder_type,
                                    shape,
                                    out_features=embed_dim)

        # NOTE: value output branches
        self.dueling = dueling
        # ins = num_outputs
        ins = embed_dim

        # Dueling case: Build the shared (advantages and value) fc-network.
        advantage_module = nn.Sequential()
        value_module = None
        if self.dueling:
            value_module = nn.Sequential()
            for i, n in enumerate(q_hiddens):
                advantage_module.add_module("dueling_A_{}".format(i),
                                            nn.Linear(ins, n))
                value_module.add_module("dueling_V_{}".format(i),
                                        nn.Linear(ins, n))
                # Add activations if necessary.
                if dueling_activation == "relu":
                    advantage_module.add_module("dueling_A_act_{}".format(i),
                                                nn.ReLU())
                    value_module.add_module("dueling_V_act_{}".format(i),
                                            nn.ReLU())
                elif dueling_activation == "tanh":
                    advantage_module.add_module("dueling_A_act_{}".format(i),
                                                nn.Tanh())
                    value_module.add_module("dueling_V_act_{}".format(i),
                                            nn.Tanh())

                # Add LayerNorm after each Dense.
                if add_layer_norm:
                    advantage_module.add_module("LayerNorm_A_{}".format(i),
                                                nn.LayerNorm(n))
                    value_module.add_module("LayerNorm_V_{}".format(i),
                                            nn.LayerNorm(n))
                ins = n
            # Actual Advantages layer (nodes=num-actions) and
            # value layer (nodes=1).
            advantage_module.add_module("A", nn.Linear(ins, action_space.n))
            value_module.add_module("V", nn.Linear(ins, 1))
        # Non-dueling:
        # Q-value layer (use main module's outputs as Q-values).
        else:
            # pass (UGLY HACK!!!)
            # NOTE: manually adding q value (no dueling) branch following embedding
            for i, n in enumerate(q_hiddens):
                advantage_module.add_module("Q_{}".format(i),
                                            nn.Linear(ins, n))
                if dueling_activation == "relu":
                    advantage_module.add_module("Q_act_{}".format(i),
                                                nn.ReLU())
                elif dueling_activation == "tanh":
                    advantage_module.add_module("Q_act_{}".format(i),
                                                nn.Tanh())
                # Add LayerNorm after each Dense.
                if add_layer_norm:
                    advantage_module.add_module("LayerNorm_Q_{}".format(i),
                                                nn.LayerNorm(n))
                ins = n
            # Actual Q value layer (nodes=num-actions) and
            # value layer (nodes=1).
            advantage_module.add_module("Q", nn.Linear(ins, action_space.n))

        self.advantage_module = advantage_module
        self.value_module = value_module
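
The snippet builds the heads but not the forward pass. A minimal sketch of how the dueling and non-dueling branches above are typically combined at inference time, assuming `model` is an instance of this class and `obs` a preprocessed NCHW observation batch:

embed = model.encoder(obs)                     # [B, embed_dim]
advantages = model.advantage_module(embed)     # [B, num_actions]
if model.dueling:
    value = model.value_module(embed)          # [B, 1]
    # Q = (A - mean[A]) + V, as described in the docstring above
    q_values = value + advantages - advantages.mean(dim=1, keepdim=True)
else:
    q_values = advantages                      # raw outputs are Q-values
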
Example #8
use_data_augmentation = not args.no_augmentation

decoder = load_model(args.decoder,
                     custom_objects={'tf': tf,
                                     'PixelShuffler': PixelShuffler,
                                     'up_bilinear': up_bilinear})
for l in decoder.layers:
    l.trainable = False
decoder.trainable = False

BS = args.batch_size
EPOCHS = args.epochs
h, w, c = decoder.output_shape[-3:]
latent_dim = decoder.input_shape[-1]

train_generator = data_generator(args.dataset, height=h, width=w, channel=c,
                                 batch_size=BS, shuffle=True,
                                 normalize=not use_data_augmentation,
                                 save_tags=False)

encoder_model, encoder = make_encoder(decoder)

seq = get_imgaug()

if not os.path.exists('./preview'):
    os.makedirs('./preview')
    
def make_some_noise():
    return np.random.normal(0, args.std, (BS, latent_dim)).astype(np.float32)

i_counter = 0
AE = 0
for epoch in range(EPOCHS):
    print("Epoch: %d / %d"%(epoch+1, EPOCHS))
    train_generator.random_shuffle()
    with tqdm(total=len(train_generator)) as t:
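        # The original snippet is truncated at this point. Everything below
        # is a hypothetical inner step, sketched from the setup above: the
        # decoder is frozen, so a plausible objective is to train the
        # encoder to invert it (assumes `encoder_model` is compiled).
        for _ in range(len(train_generator)):
            z = make_some_noise()                      # [BS, latent_dim] latents
            x = decoder.predict(z)                     # render with frozen decoder
            loss = encoder_model.train_on_batch(x, z)  # fit encoder to recover z
            t.update(1)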
Example #9
def train(
    _log: Logger,
    num_codes: int,
    latent_size: int,
    z_dimension: int,
    batch_size: int,
    beta: float,
):
    train_dataset, _ = create_dataset()

    x = train_dataset.batch(batch_size).make_one_shot_iterator().get_next()

    with tf.name_scope("data"):
        tf.summary.image("mnist_image", x)

    writer = create_writer()

    _log.info("Building computation graph...")
    encoder = make_encoder(latent_size, z_dimension)
    decoder = make_decoder(latent_size, z_dimension)
    quantizer = VectorQuantizer(num_codes, z_dimension)

    codes = encoder(x)
    nearest_codebook_entries, _ = quantizer(codes)

    # Pass nearest codebook entries to the decoder and pass its gradients
    # straight through to the encoder.
    # => Forward pass: nearest_codebook_entries, Backpropagation: codes
    codes_straight_through = codes + tf.stop_gradient(
        nearest_codebook_entries - codes)
    decoder_distribution = tf.distributions.Bernoulli(
        logits=decoder(codes_straight_through))

    # Model.summary() prints its table directly (and returns None).
    encoder.summary()
    decoder.summary()

    with tf.variable_scope("posterior", reuse=True):
        posterior_sample = decoder_distribution.mean()
        tf.summary.image("posterior_sample",
                         tf.cast(posterior_sample, tf.float32))

    reconstruction_loss = -tf.reduce_mean(decoder_distribution.log_prob(x))
    commitment_loss = tf.reduce_mean(
        tf.square(codes - tf.stop_gradient(nearest_codebook_entries)))
    # The embedding loss moves the codebook entries toward the (frozen)
    # encoder outputs, so stop_gradient goes on `codes` here.
    embedding_loss = tf.reduce_mean(
        tf.square(tf.stop_gradient(codes) - nearest_codebook_entries))

    # Uniform prior over codes
    prior_distribution = tf.distributions.Multinomial(
        total_count=1.0, logits=tf.zeros([latent_size, num_codes]))

    with tf.variable_scope("prior", reuse=True):
        prior_sample_codes = quantizer.get_codebook_entries(
            prior_distribution.sample(1))
        prior_decoder_distribution = tf.distributions.Bernoulli(
            logits=decoder(prior_sample_codes))
        prior_sample = prior_decoder_distribution.mean()
        tf.summary.image("prior_sample", tf.cast(prior_sample, tf.float32))

    loss = reconstruction_loss + embedding_loss + beta * commitment_loss

    tf.summary.scalar("losses/total_loss", loss)
    tf.summary.scalar("losses/embedding_loss", embedding_loss)
    tf.summary.scalar("losses/reconstruction_loss", reconstruction_loss)
    tf.summary.scalar("losses/commitment_loss", beta * commitment_loss)

    train_op = tf.train.AdamOptimizer().minimize(loss)
    summary_op = tf.summary.merge_all()

    run_training(train_op, summary_op, writer)
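
The decisive trick in this example is the straight-through estimator (the `codes_straight_through` line above). A self-contained sketch isolating it:

import tensorflow as tf  # TF1-style, matching the example above

codes = tf.constant([0.3, 0.7])     # stand-in encoder outputs
nearest = tf.constant([0.0, 1.0])   # stand-in quantized codebook entries
codes_st = codes + tf.stop_gradient(nearest - codes)
# forward pass:  codes_st evaluates to `nearest`
# backward pass: d(codes_st)/d(codes) is the identity, so reconstruction
# gradients bypass the non-differentiable quantizer and reach the encoder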