예제 #1
0
 def autoreg(i, x, vs, y, p):
     """One step of autoregressive decoding (a closure over the enclosing scope).

     Embeds the current token `x`, appends it to every decoder layer's cache,
     runs the decoder stack, then appends the resulting logits to `y` and the
     chosen next token to `p`.  The return value has the same structure as the
     arguments, presumably so this can serve as a `tf.while_loop` body —
     confirm against the caller.

     Free names read from the enclosing scope (not visible here): `pos`,
     `dropout`, `emb_tgt`, `decode`, `w`, `mask`, `logit`, `random`.
     `random` is a flag: truthy samples the next token, falsy takes the argmax.
     """
     # i : ()              time step from 0 to t=len_tgt
     # x : (b, 1)          x_i
     # vs: tuple with one cache per layer of `decode`; each presumably
     #     (b, t, dim) attention values — confirm
     # y : (b, t, dim_tgt) logit over x one step ahead
     # p : (b, t)          predictions
     with tf.variable_scope('emb_tgt'):
         x = pos[i] + dropout(emb_tgt.embed(x))
     us = []
     for dec, v in zip(decode, vs):
         with tf.variable_scope('cache_v'):
             # grow this layer's cache by the current step along axis 1
             v = tf.concat((v, x), 1)
             us.append(v)
         # the layer's output becomes the next layer's input (and cache entry)
         x = dec(x, v, w, mask, dropout)
     x = logit(x)
     with tf.variable_scope('cache_y'):
         y = tf.concat((y, x), 1)
     if random:
         # tf.multinomial takes [batch, classes] logits, hence squeezing axis 1
         with tf.variable_scope('sample'):
             x = tf.multinomial(tf.squeeze(x, 1),
                                1,
                                output_dtype=tf.int32)
     else:
         # greedy decoding
         x = tf.argmax(x, -1, output_type=tf.int32, name='argmax')
     with tf.variable_scope('cache_p'):
         p = tf.concat((p, x), 1)
     return i + 1, x, tuple(us), y, p
예제 #2
0
파일: model.py 프로젝트: ysmiraak/tau
 def body(i, q):
     """One greedy decoding step (a closure over the enclosing method's scope).

     `q` holds the tokens decoded so far (`bj` in the shape comments).  The
     whole prefix is re-embedded and re-decoded; the decoder output at the
     last position is passed through `self.emb_tgt` to get scores over the
     vocabulary (`bn`), and its argmax is appended to `q`.
     Returns `(i + 1, q_extended)`, matching the argument structure —
     presumably a `tf.while_loop` body; confirm against the caller.

     Free names from the enclosing scope: `pos`, `msk`, `w`, `dropout`, `self`.
     """
     j = i + 1
     x = pos[:,:j] + self.emb_tgt(q) # bdj <- bj
     # causal mask restricted to the current j x j prefix
     x = self.decode(x, msk[:,:j,:j], w, self.mask_src, dropout) # bdj
     # emb_tgt is used both to embed tokens and to score a d-vector against the
     # vocabulary — presumably tied input/output embeddings; confirm
     p = tf.expand_dims( # b1
         tf.argmax( # b
             self.emb_tgt( # bn
                 tf.squeeze( # bd
                     x[:,:,-1:] # bd1 <- bdj
                     , axis= -1))
             , axis= -1, output_type= tf.int32)
         , axis= -1)
     return j, tf.concat((q, p), axis= -1) # bk <- bj, b1
예제 #3
0
파일: model.py 프로젝트: ysmiraak/tau
def sinusoid(dim, time, freq= 1e-4, array= False):
    """Return a `dim x time` table of sinusoidal position encodings.

    Column `c` encodes time step `c + 1` (steps start at 1).  Rows come in
    (sin, cos) pairs: rows `2k` and `2k + 1` hold `sin` and `cos` of
    `f_k * (c + 1)` with `f_k = freq ** (2k / dim)` for `k = 0 .. dim/2 - 1`,
    a geometric progression that starts at 1 and decreases towards `freq`
    (the last frequency is `freq ** (1 - 2/dim)`).

    Only the numpy path (`array=True`) is enabled.  With `array=False` an
    AssertionError is raised; the tensorflow code after it is kept for
    reference.  `dim` must be even.
    """
    assert not dim % 2
    if not array:
        # tensorflow path disabled (original note: figure out a better way to do this)
        assert False
        rates = tf.reshape(
            freq ** ((2 / dim) * tf.range(dim // 2, dtype= tf.float32))
            , (-1, 1))
        steps = tf.reshape(
            1 + tf.range(tf.to_float(time), dtype= tf.float32)
            , (1, -1))
        angles = rates @ steps
        return tf.reshape(tf.concat((tf.sin(angles), tf.cos(angles)), axis= -1), (dim, time))
    rates = freq ** ((2 / dim) * np.arange(dim // 2))
    # angles[k, c] = f_k * (c + 1)
    angles = np.outer(rates, 1 + np.arange(time))
    # (dim/2, 2, time) -> (dim, time) puts sin and cos of frequency k on rows 2k, 2k+1
    return np.stack((np.sin(angles), np.cos(angles)), axis= 1).reshape(dim, time)
예제 #4
0
def sinusoid(time, dim, freq=1e-4, name='sinusoid', scale=True, array=False):
    """Return a `time x dim` table of sinusoidal position encodings.

    Row `t` encodes time step `t` (steps start at 0).  Columns come in
    (sin, cos) pairs: columns `2k` and `2k + 1` hold `sin` and `cos` of
    `f_k * t` with `f_k = freq ** (2k / dim)` for `k = 0 .. dim/2 - 1`, a
    geometric progression that starts at 1 and decreases towards `freq`
    (the last frequency is `freq ** (1 - 2/dim)`).  When `scale` is true
    every entry is multiplied by `dim ** -0.5`.

    With `array=True` a numpy array is returned; otherwise a tensorflow
    tensor built inside variable scope `name`.  `dim` must be even.
    """
    assert not dim % 2
    if not array:
        with tf.variable_scope(name):
            rates = tf.reshape(
                freq**((2 / dim) * tf.range(dim // 2, dtype=tf.float32)),
                (-1, 1))
            steps = tf.reshape(
                tf.range(tf.cast(time, tf.float32), dtype=tf.float32), (1, -1))
            angles = rates @ steps
            table = tf.reshape(tf.concat((tf.sin(angles), tf.cos(angles)), -1),
                               (dim, time))
            if scale:
                table *= dim**-0.5
            return tf.transpose(table)
    rates = freq**((2 / dim) * np.arange(dim // 2))
    # angles[k, t] = f_k * t
    angles = np.outer(rates, np.arange(time))
    # (dim/2, 2, time) -> (dim, time): sin/cos of frequency k on rows 2k, 2k+1
    table = np.stack((np.sin(angles), np.cos(angles)), axis=1).reshape(dim, time)
    if scale:
        table *= dim**-0.5
    return table.T
예제 #5
0
파일: model.py 프로젝트: argsim/argsim
def vAe(
        mode,
        src=None,
        tgt=None,
        # model spec
        dim_tgt=8192,
        dim_emb=512,
        dim_rep=1024,
        rnn_layers=3,
        bidirectional=True,
        bidir_stacked=True,
        attentive=False,
        logit_use_embed=True,
        # training spec
        accelerate=1e-4,
        learn_rate=1e-3,
        bos=2,
        eos=1):
    """Build the TF1 graph of a recurrent sequence variational autoencoder.

    An RNN encoder summarises `src` into one vector per sequence, which is
    mapped to the mean `mu` and log-variance `lv` of a diagonal Gaussian
    latent `z`.  `z` is projected to the initial state of every decoder
    layer, and an RNN decoder predicts `tgt + [eos]` from `[bos] + tgt`.

    Args:
      mode: one of 'train', 'valid', 'infer'.  'train' applies word dropout
        to the decoder input, samples `z` with the reparameterisation trick
        and builds `train_step`; 'infer' keeps decoder outputs for all steps
        (no padding mask) and builds no error/loss nodes.
      src, tgt: given to the `placeholder` helper with dtype int32 and shape
        (None, None); presumably batch-major token ids (they are transposed
        to time-major below) — confirm.
      dim_tgt: vocabulary size (rows of the shared embedding matrix).
      dim_emb: embedding width, also passed as the width to `layer_rnn`.
      dim_rep: latent dimension.
      rnn_layers: number of RNN layers in the encoder and the decoder.
      bidirectional, bidir_stacked: encoder layout — a stack of per-layer
        bidirectional RNNs, one bidirectional multi-layer pair, or a single
        unidirectional multi-layer RNN.
      attentive: adds attention over the encoder outputs to the summary
        vector (marked todo/fixme below).
      logit_use_embed: compute logits with the scaled, transposed embedding
        matrix instead of a separate affine layer.
      accelerate: multiplies the global step to give `rate`, which drives the
        word keep probability `sigmoid(rate)`, the KL weight `tanh(rate)` and
        the learning rate `learn_rate / (sqrt(rate) + 1)`.
      learn_rate: base learning rate for Adam.
      bos, eos: begin/end-of-sequence ids.  Id 0 is what word dropout
        writes in place of dropped tokens (unk).

    Returns:
      A `Record` holding the graph nodes: step, rate_keepwd, rate_anneal,
      rate_update, src, tgt, lead, mu, lv, z, state_in, state_ex, logits,
      prob, pred; outside 'infer' also errt_samp, errt, loss_gen_samp,
      loss_gen, loss_kld_samp, loss_kld, loss; in 'train' also train_step.
    """

    # dim_tgt : vocab size
    # dim_emb : model dimension
    # dim_rep : representation dimension
    #
    # unk=0 for word dropout

    assert mode in ('train', 'valid', 'infer')
    # `self` is a plain Record collecting the nodes handed back to the caller
    self = Record(bos=bos, eos=eos)

    with scope('step'):
        step = self.step = tf.train.get_or_create_global_step()
        rate = accelerate * tf.to_float(step)
        # for accelerate > 0: keep prob rises from 0.5 towards 1,
        # KL weight rises from 0 towards 1, learning rate decays from learn_rate
        rate_keepwd = self.rate_keepwd = tf.sigmoid(rate)
        rate_anneal = self.rate_anneal = tf.tanh(rate)
        rate_update = self.rate_update = learn_rate / (tf.sqrt(rate) + 1.0)

    with scope('src'):
        src = self.src = placeholder(tf.int32, (None, None), src, 'src')
        src = tf.transpose(src)  # time major order
        # trim presumably removes trailing all-eos steps and returns the ids,
        # the (s, b) validity mask and the (b,) lengths — confirm
        src, msk_src, len_src = trim(src, eos)

    with scope('tgt'):
        tgt = self.tgt = placeholder(tf.int32, (None, None), tgt, 'tgt')
        tgt = tf.transpose(tgt)  # time major order
        tgt, msk_tgt, len_tgt = trim(tgt, eos)
        # one extra leading True so each sequence's mask covers len + 1 steps
        # (its tokens plus one eos).  Applied to `gold` this picks the right
        # positions only if padding in `tgt` is eos — presumably guaranteed
        # by trim; confirm.
        msk_tgt = tf.pad(msk_tgt, ((1, 0), (0, 0)), constant_values=True)
        # pads for decoder : lead=[bos]+tgt -> gold=tgt+[eos]
        lead, gold = tgt, tf.pad(tgt,
                                 paddings=((0, 1), (0, 0)),
                                 constant_values=eos)
        if 'train' == mode:
            # word dropout: each input token becomes 0 (unk) with
            # probability 1 - rate_keepwd
            lead *= tf.to_int32(
                tf.random_uniform(tf.shape(lead)) < rate_keepwd)
        lead = self.lead = tf.pad(lead,
                                  paddings=((1, 0), (0, 0)),
                                  constant_values=bos)

    # s : src length
    # t : tgt length plus one padding, either eos or bos
    # b : batch size
    #
    # len_src :  b  aka s
    # msk_src : sb  without padding
    # msk_tgt : tb  with eos
    #
    #    lead : tb  with bos
    #    gold : tb  with eos

    with scope('embed'):
        # one embedding matrix shared by source and target inputs (and by the
        # output logits when logit_use_embed), uniform in [-b, b]
        b = (6 / (dim_tgt / dim_emb + 1))**0.5
        embedding = tf.get_variable('embedding', (dim_tgt, dim_emb),
                                    initializer=tf.random_uniform_initializer(
                                        -b, b))
        emb_tgt = tf.gather(embedding, lead,
                            name='emb_tgt')  # (t, b) -> (t, b, dim_emb)
        emb_src = tf.gather(embedding, src,
                            name='emb_src')  # (s, b) -> (s, b, dim_emb)

    with scope('encode'):  # (s, b, dim_emb) -> (b, dim_emb)
        # reverses each sequence only within its own length, so padding
        # stays at the end for the backward RNN
        reverse = partial(tf.reverse_sequence,
                          seq_lengths=len_src,
                          seq_axis=0,
                          batch_axis=1)

        if bidirectional and bidir_stacked:
            # each layer feeds the concatenated fwd/bwd outputs to the next;
            # NOTE(review): hs is undefined if rnn_layers == 0
            for i in range(rnn_layers):
                with scope("rnn{}".format(i + 1)):
                    emb_fwd, _ = layer_rnn(1, dim_emb, name='fwd')(emb_src)
                    emb_bwd, _ = layer_rnn(1, dim_emb,
                                           name='bwd')(reverse(emb_src))
                    hs = emb_src = tf.concat((emb_fwd, reverse(emb_bwd)),
                                             axis=-1)

        elif bidirectional:
            with scope("rnn"):
                emb_fwd, _ = layer_rnn(rnn_layers, dim_emb,
                                       name='fwd')(emb_src)
                emb_bwd, _ = layer_rnn(rnn_layers, dim_emb,
                                       name='bwd')(reverse(emb_src))
            hs = tf.concat((emb_fwd, reverse(emb_bwd)), axis=-1)

        else:
            hs, _ = layer_rnn(rnn_layers, dim_emb, name='rnn')(emb_src)

        with scope('cata'):
            # extract the final states from the outputs: bd <- sbd, b2
            # NOTE(review): in the bidirectional cases, position len_src - 1
            # of the re-reversed backward output is the backward RNN's *first*
            # step, not its final state — confirm this is intended.
            h = tf.gather_nd(
                hs,
                tf.stack(
                    (len_src - 1, tf.range(tf.size(len_src), dtype=tf.int32)),
                    axis=1))
            if attentive:  # todo fixme
                # the values are the outputs from all non-padding steps;
                # the queries are the final states;
                h = layer_nrm(h + tf.squeeze(  # bd <- bd1
                    attention(  # bd1 <- bd1, bds, b1s
                        tf.expand_dims(h, axis=2),  # query: bd1 <- bd
                        tf.transpose(hs, (1, 2, 0)),  # value: bds <- sbd
                        tf.log(
                            tf.to_float(  # -inf,0  mask: b1s <- sb <- bs
                                tf.expand_dims(tf.transpose(msk_src),
                                               axis=1))),
                        int(h.shape[-1])),
                    2))

    with scope('latent'):  # (b, dim_emb) -> (b, dim_rep) -> (b, dim_emb)
        # h = layer_aff(h, dim_emb, name='in')
        mu = self.mu = layer_aff(h, dim_rep, name='mu')
        lv = self.lv = layer_aff(h, dim_rep, name='lv')  # log variance
        with scope('z'):
            # reparameterisation: z = mu + exp(lv / 2) * eps when training,
            # z = mu otherwise
            h = mu
            if 'train' == mode:
                h += tf.exp(0.5 * lv) * tf.random_normal(shape=tf.shape(lv))
            self.z = h
        h = layer_aff(h, dim_emb, name='ex')

    with scope('decode'):  # (b, dim_emb) -> (t, b, dim_emb) -> (?, dim_emb)
        # the projected latent is the initial state of every decoder layer
        h = self.state_in = tf.stack((h, ) * rnn_layers)
        # the (outputs, (state,)) pair is unpacked twice: outputs into h and
        # the final state into self.state_ex
        h, _ = _, (self.state_ex, ) = layer_rnn(rnn_layers,
                                                dim_emb,
                                                name='rnn')(
                                                    emb_tgt,
                                                    initial_state=(h, ))
        # outside inference keep only non-padding steps, flattened to (?, dim_emb)
        if 'infer' != mode: h = tf.boolean_mask(h, msk_tgt)
        h = layer_aff(h, dim_emb, name='out')

    with scope('logits'):  # (?, dim_emb) -> (?, dim_tgt)
        if logit_use_embed:
            # tied output layer: scaled dot product with every embedding row
            logits = self.logits = tf.tensordot(h, (dim_emb**-0.5) *
                                                tf.transpose(embedding), 1)
        else:
            logits = self.logits = layer_aff(h, dim_tgt)

    with scope('prob'):
        prob = self.prob = tf.nn.softmax(logits)
    with scope('pred'):
        pred = self.pred = tf.argmax(logits, -1, output_type=tf.int32)

    if 'infer' != mode:
        labels = tf.boolean_mask(gold, msk_tgt, name='labels')
        with scope('errt'):
            # token error rate over non-padding positions
            errt_samp = self.errt_samp = tf.to_float(tf.not_equal(
                labels, pred))
            errt = self.errt = tf.reduce_mean(errt_samp)
        with scope('loss'):
            with scope('loss_gen'):
                loss_gen_samp = self.loss_gen_samp = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)
                loss_gen = self.loss_gen = tf.reduce_mean(loss_gen_samp)
            with scope('loss_kld'):
                # closed-form KL(N(mu, exp(lv)) || N(0, 1)) per latent unit
                loss_kld_samp = self.loss_kld_samp = 0.5 * (
                    tf.square(mu) + tf.exp(lv) - lv - 1.0)
                loss_kld = self.loss_kld = tf.reduce_mean(loss_kld_samp)
            # KL term is annealed in by rate_anneal
            loss = self.loss = rate_anneal * loss_kld + loss_gen

    if 'train' == mode:
        with scope('train'):
            train_step = self.train_step = tf.train.AdamOptimizer(
                rate_update).minimize(loss, step)

    return self
예제 #6
0
def model(mode,
          src_dwh,
          tgt_dwh,
          src_idx=None,
          len_src=None,
          tgt_img=None,
          tgt_idx=None,
          len_tgt=None,
          num_layers=3,
          num_units=512,
          learn_rate=1e-3,
          decay_rate=1e-2,
          dropout=0.1):
    """Build a TF1 graph mapping source indices to target images and indices.

    `num_layers` bidirectional CudnnGRU layers encode the one-hot source
    `src_idx`.  A CudnnGRU decoder, started from a zero state, runs over the
    right-shifted target images, attends to the encoder outputs, and two dense
    heads predict the next image (clipped to [0, 1]) and the next target index.
    Only the index cross entropy is optimised; `mae` / `mse` on the image head
    are computed for monitoring.

    Args:
      mode: 'train', 'valid' or 'infer'.  'train' enables GRU and attention
        dropout and builds the optimiser op `down`; outside 'infer' decoder
        outputs and targets are restricted to non-padding steps.
      src_dwh, tgt_dwh: `(d, w, h)` triples.  `src_d` is the one-hot depth of
        `src_idx`; `tgt_d` is the size of the index head; `tgt_h`, `tgt_w` are
        the target image size.  `src_w`, `src_h` are unused.
      src_idx, len_src, tgt_img, tgt_idx, len_tgt: given to the `placeholder`
        helper (shapes in the comments below: n batch, s / t source / target
        length, h w image size).
      num_layers: encoder depth and decoder depth.
      num_units: GRU width (also used for the attention sizes).
      learn_rate, decay_rate: Adam learning rate is
        `learn_rate / (1 + decay_rate * sqrt(step))`.
      dropout: dropout rate for the GRUs and the attention output.

    Returns:
      A `Record` with the placeholders and graph nodes: emb_src, fire, true,
      tidx, decoder, state_in, state_ex, pred, prob, pidx, mae, mse, xid,
      err, step, lr and, in 'train', down.
    """
    assert mode in ('train', 'valid', 'infer')
    self = Record()

    src_d, src_w, src_h = src_dwh
    tgt_d, tgt_w, tgt_h = tgt_dwh

    with scope('source'):
        # input nodes
        src_idx = self.src_idx = placeholder(tf.int32, (None, None), src_idx,
                                             'src_idx')  # n s
        len_src = self.len_src = placeholder(tf.int32, (None, ), len_src,
                                             'len_src')  # n

        # time major order
        src_idx = tf.transpose(src_idx, (1, 0))  # s n
        emb_src = tf.one_hot(src_idx, src_d)  # s n v

        # stacked bidirectional encoder; the backward GRU reads each sequence
        # reversed within its own length, and its output is reversed back
        for i in range(num_layers):
            with scope("rnn{}".format(i + 1)):
                emb_fwd, _ = tf.contrib.cudnn_rnn.CudnnGRU(
                    1, num_units, dropout=dropout,
                    name='fwd')(emb_src, training='train' == mode)
                emb_bwd, _ = tf.contrib.cudnn_rnn.CudnnGRU(
                    1, num_units, dropout=dropout,
                    name='bwd')(tf.reverse_sequence(emb_src,
                                                    len_src,
                                                    seq_axis=0,
                                                    batch_axis=1),
                                training='train' == mode)
            emb_src = tf.concat(
                (emb_fwd,
                 tf.reverse_sequence(
                     emb_bwd, len_src, seq_axis=0, batch_axis=1)),
                axis=-1)
        # emb_src = tf.layers.dense(emb_src, num_units, name= 'reduce_concat') # s n d
        emb_src = self.emb_src = tf.transpose(emb_src, (1, 2, 0))  # n d s

    with scope('target'):
        # input nodes
        tgt_img = self.tgt_img = placeholder(tf.uint8,
                                             (None, None, tgt_h, tgt_w),
                                             tgt_img, 'tgt_img')  # n t h w
        tgt_idx = self.tgt_idx = placeholder(tf.int32, (None, None), tgt_idx,
                                             'tgt_idx')  # n t
        len_tgt = self.len_tgt = placeholder(tf.int32, (None, ), len_tgt,
                                             'len_tgt')  # n

        # time major order
        tgt_idx = tf.transpose(tgt_idx)  # t n
        tgt_img = tf.transpose(tgt_img, (1, 0, 2, 3))  # t n h w
        tgt_img = flatten(tgt_img, 2, 3)  # t n hw

        # scale uint8 pixels from [0, 255] to [0, 1] (binarising with
        # tf.round is disabled below)
        tgt_img = tf.to_float(tgt_img) / 255.0
        # tgt_img = tf.round(tgt_img)
        # todo consider adding noise

        # teacher forcing: the decoder input `fire` is the image sequence
        # shifted right by one all-zero frame; the image target `true` is the
        # sequence followed by one all-ones frame; the index target `tidx`
        # is the index sequence followed by index 1
        fire = self.fire = tf.pad(tgt_img, ((1, 0), (0, 0), (0, 0)),
                                  constant_values=0.0)
        true = self.true = tf.pad(tgt_img, ((0, 1), (0, 0), (0, 0)),
                                  constant_values=1.0)
        tidx = self.tidx = tf.pad(tgt_idx, ((0, 1), (0, 0)), constant_values=1)
        # len_tgt + 1 valid steps per sequence, covering the appended step
        mask_tgt = tf.transpose(tf.sequence_mask(len_tgt + 1))  # t n

    with scope('decode'):
        # TODO (original author): the decoder should get input from the
        # latent space / attention; for now it starts from a zero state
        decoder = self.decoder = tf.contrib.cudnn_rnn.CudnnGRU(num_layers,
                                                               num_units,
                                                               dropout=dropout)
        state_in = self.state_in = tf.zeros(
            (num_layers, tf.shape(fire)[1], num_units))
        # outputs go to x, the final state tuple to self.state_ex
        x, _ = _, (self.state_ex, ) = decoder(fire,
                                              initial_state=(state_in, ),
                                              training='train' == mode)
        # log of the 0/1 source mask gives 0 for valid steps and -inf for
        # padding, an additive mask for the attention scores (presumably
        # added inside Attention — confirm)
        mask = tf.log(tf.sequence_mask(len_src, dtype=tf.float32))  # n s
        mask = tf.expand_dims(mask, 1)  # n 1 s
        # multi-head scaled dot-product attention
        x = tf.transpose(x, (1, 2, 0))  # t n d ---> n d t
        attn = Attention(num_units, num_units, 2 * num_units)(x, emb_src, mask)
        # TF1 tf.nn.dropout takes a keep probability, hence 1 - dropout
        if 'train' == mode: attn = tf.nn.dropout(attn, 1 - dropout)
        # residual connection followed by normalisation
        x = Normalize(num_units)(x + attn)
        x = tf.transpose(x, (2, 0, 1))  # n d t ---> t n d

    if 'infer' != mode:
        # keep only non-padding steps (flattens the time and batch axes)
        x = tf.boolean_mask(x, mask_tgt)
        true = tf.boolean_mask(true, mask_tgt)
        tidx = tf.boolean_mask(tidx, mask_tgt)

    with scope('output'):
        y = tf.layers.dense(x, tgt_h * tgt_w, name='dense_img')
        z = tf.layers.dense(x, tgt_d, name='logit_idx')
        pred = self.pred = tf.clip_by_value(y, 0.0, 1.0)
        prob = self.prob = tf.nn.softmax(z)
        pidx = self.pidx = tf.argmax(z, axis=-1, output_type=tf.int32)

    with scope('losses'):
        # image errors (monitoring only; not part of the optimised loss)
        diff = true - pred
        mae = self.mae = tf.reduce_mean(tf.abs(diff), axis=-1)
        mse = self.mse = tf.reduce_mean(tf.square(diff), axis=-1)
        xid = self.xid = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=z, labels=tidx)
        err = self.err = tf.not_equal(tidx, pidx)
        loss = tf.reduce_mean(xid)

    with scope('update'):
        step = self.step = tf.train.get_or_create_global_step()
        lr = self.lr = learn_rate / (1.0 +
                                     decay_rate * tf.sqrt(tf.to_float(step)))
        if 'train' == mode:
            down = self.down = tf.train.AdamOptimizer(lr).minimize(loss, step)

    return self
def _load_dataset(dataset, anomaly_class):
    """Return `(x_train, y_train, x_test, y_test)` for `dataset`.

    * "ucsd1" / "ucsd2" (also the historical misspelling "uscd2"):
      pre-split arrays read from the `arr_0` entry of
      `./data/<name>_{train,test}_{x,y}.npz`; only ucsd1 images are divided
      by 255.
    * "mnist" (resized with `resize_images`) or any other name (cifar10):
      a one-class anomaly split.  Training images are all training images
      whose label differs from `anomaly_class`, with dummy int8 zero labels.
      Test images are the `anomaly_class` training images followed by the
      whole test split, labelled 1 for `anomaly_class` and 0 otherwise, and
      shuffled with `unison_shfl`.  Pixels are divided by 255.
    """
    if dataset == "ucsd1":
        x_train = np.load("./data/ucsd1_train_x.npz")["arr_0"] / 255
        y_train = np.load("./data/ucsd1_train_y.npz")["arr_0"]
        x_test = np.load("./data/ucsd1_test_x.npz")["arr_0"] / 255
        y_test = np.load("./data/ucsd1_test_y.npz")["arr_0"]
        return x_train, y_train, x_test, y_test

    # "uscd2" was the original (misspelled) key; keep accepting it so old
    # callers still get ucsd2 instead of silently falling through to cifar10
    if dataset in ("ucsd2", "uscd2"):
        # NOTE(review): unlike ucsd1 these arrays are not divided by 255 —
        # confirm the ucsd2 files are already scaled.
        x_train = np.load("./data/ucsd2_train_x.npz")["arr_0"]
        y_train = np.load("./data/ucsd2_train_y.npz")["arr_0"]
        x_test = np.load("./data/ucsd2_test_x.npz")["arr_0"]
        y_test = np.load("./data/ucsd2_test_y.npz")["arr_0"]
        return x_train, y_train, x_test, y_test

    if dataset == "mnist":
        (train_images, train_labels), (
            test_images,
            test_labels) = tf.keras.datasets.mnist.load_data()
        train_images = resize_images(train_images)
        test_images = resize_images(test_images)
    else:
        (train_images, train_labels), (
            test_images,
            test_labels) = tf.keras.datasets.cifar10.load_data()
        # cifar10 labels come as (n, 1); flatten them to (n,)
        train_labels = np.reshape(train_labels, len(train_labels))
        test_labels = np.reshape(test_labels, len(test_labels))

    is_anomaly = train_labels == anomaly_class
    x_train = train_images[~is_anomaly] / 255
    y_train = np.zeros(len(x_train), dtype=np.int8)  # dummy
    x_test = np.concatenate([train_images[is_anomaly], test_images]) / 255
    y_test = np.concatenate([train_labels[is_anomaly], test_labels])
    y_test = (y_test == anomaly_class).astype(int)  # 1 = anomaly
    x_test, y_test = unison_shfl(x_test, y_test)
    return x_train, y_train, x_test, y_test


def train(anomaly_class=8,
          dataset="cifar",
          n_dis=1,
          epochs=25,
          dim_btlnk=32,
          batch_size=64,
          loss="mean",
          context_weight=1,
          dim_d=64,
          dim_g=64,
          extra_layers=0,
          gpu="0"):
    """Train an MG_GAN anomaly detector and log it to TensorBoard.

    Builds the graph with `MG_GAN.new` / `MG_GAN.build`, runs one
    discriminator step (`d_step`) and one generator step (`g_step`) per
    batch, writes test-set summaries after every epoch (plus one writer per
    discriminator when `n_dis > 1`) under `/cache/tensorboard-logdir/<dataset>`,
    and finally saves a checkpoint under `/project/multi-discriminator-gan/ckpt`.

    Args:
      anomaly_class: label treated as the anomaly for mnist / cifar10; also
        part of the run name.
      dataset: "ucsd1", "ucsd2" (the misspelling "uscd2" is also accepted),
        "mnist", or any other value for cifar10.
      n_dis: number of discriminators.
      epochs: number of training epochs.
      dim_btlnk, dim_d, dim_g, extra_layers: forwarded to `MG_GAN.new`.
      batch_size: batch size for training and for test-set evaluation.
      loss, context_weight: forwarded to `MG_GAN.build`.
      gpu: value for the CUDA_VISIBLE_DEVICES environment variable.
    """
    # restrict TensorFlow to the chosen GPU(s)
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    path_log = f"/cache/tensorboard-logdir/{dataset}"
    path_ckpt = "/project/multi-discriminator-gan/ckpt"

    # reset the graph and fix the graph-level seed
    tf.reset_default_graph()
    # Close a session left in the module globals by an earlier run.  `sess`
    # is assigned further down, which makes it a local name throughout this
    # function, so the global one must be fetched via globals(); a bare
    # `sess.close()` here would raise UnboundLocalError.
    stale_sess = globals().get('sess')
    if stale_sess is not None:
        stale_sess.close()
    tf.set_random_seed(0)

    x_train, y_train, x_test, y_test = _load_dataset(dataset, anomaly_class)

    img_size_x = x_train[0].shape[0]
    img_size_y = x_train[0].shape[1]
    channel = x_train[0].shape[-1]
    trial = f"{dataset}_{loss}_dis{n_dis}_{anomaly_class}_w{context_weight}_btlnk{dim_btlnk}_d{dim_d}_g{dim_g}e{extra_layers}"

    # data pipeline
    batch_fn = lambda: batch2(x_train, y_train, batch_size)
    x, y = pipe(batch_fn, (tf.float32, tf.float32), prefetch=4)

    # build the graph; `extra_layers` is forwarded (it used to be hard-coded
    # to 0 even though the run name records the requested value)
    mg_gan = MG_GAN.new(img_size_x,
                        channel,
                        dim_btlnk,
                        dim_d,
                        dim_g,
                        n_dis,
                        extra_layers=extra_layers)
    model = MG_GAN.build(mg_gan, x, y, context_weight, loss)

    # start session, initialize variables
    sess = tf.InteractiveSession()
    saver = tf.train.Saver()

    wrtr = tf.summary.FileWriter(pform(path_log, trial))
    wrtr.add_graph(sess.graph)

    # To resume from a pretrained model instead of initialising:
    #     saver.restore(sess, pform(path_ckpt, "modelname"))
    # The AUC variables live in the LOCAL_VARIABLES collection (scope 'AUC'),
    # which the global initializer does not cover.
    auc_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='AUC')
    init = tf.group(tf.global_variables_initializer(),
                    tf.variables_initializer(var_list=auc_vars))
    sess.run(init)

    summary_test = tf.summary.merge([
        tf.summary.scalar('g_loss', model.g_loss),
        tf.summary.scalar("lambda", model.lam),
        tf.summary.scalar("gl_rec", model.gl_rec),
        tf.summary.scalar("gl_adv", model.gl_adv),
        tf.summary.scalar("gl_lam", model.gl_lam),
        tf.summary.scalar('d_loss_mean', model.d_loss_mean),
        tf.summary.scalar('d_max', model.d_max),
        tf.summary.scalar("AUC_gx", model.auc_gx),
    ])
    if dataset == "ucsd1":
        summary_images = tf.summary.merge(
            (tf.summary.image("gx", model.gx, max_outputs=8),
             tf.summary.image("x", model.x, max_outputs=8),
             tf.summary.image(
                 'gx400',
                 spread_image(tf.concat([model.gx, model.x], axis=1), 8, 2,
                              img_size_x, img_size_y, channel))))
    else:
        summary_images = tf.summary.merge(
            (tf.summary.image("gx", model.gx, max_outputs=8),
             tf.summary.image(
                 'gx400',
                 spread_image(model.gx[:400], 20, 20, img_size_x, img_size_y,
                              channel)),
             tf.summary.image("x", model.x, max_outputs=8)))

    if n_dis > 1:
        # one writer per discriminator so TensorBoard overlays their losses
        d_wrtr = {
            i: tf.summary.FileWriter(pform(path_log, trial + f"d{i}"))
            for i in range(n_dis)
        }
        summary_discr = {
            i: tf.summary.scalar('d_loss_multi', model.d_loss[i])
            for i in range(n_dis)
        }

    def mean_over_test(fetches):
        """Run `fetches` on every test batch and return each fetch's mean."""
        batches = (sess.run(fetches, {
            model['x']: x_test[i:j],
            model['y']: y_test[i:j]
        }) for i, j in partition(len(x_test), batch_size, discard=False))
        return [np.mean(values) for values in zip(*batches)]

    def summ(step):
        """Write test-set scalar summaries and sample images for `step`."""
        fetches = model.g_loss, model.lam, model.d_loss_mean, model.auc_gx
        results = mean_over_test(fetches)
        # The test-set means are fed in place of the fetched tensors.
        # NOTE(review): gl_rec, gl_adv, gl_lam and d_max are not fed, so their
        # summaries are presumably computed on a batch pulled from the
        # training pipeline rather than on the test set — confirm.
        wrtr.add_summary(sess.run(summary_test, dict(zip(fetches, results))),
                         step)

        if dataset == "ucsd1":
            # bike, skateboard, grasswalk, shopping cart, car, normal, normal, grass
            wrtr.add_summary(
                sess.run(
                    summary_images, {
                        model.x:
                        x_test[[990, 1851, 2140, 2500, 2780, 2880, 3380, 3580]]
                    }), step)
        else:
            wrtr.add_summary(sess.run(summary_images, {model.x: x_test}), step)
        wrtr.flush()

    def summ_discr(step):
        """Write each discriminator's mean test-set loss to its own writer."""
        fetches = model.d_loss
        results = mean_over_test(fetches)
        if n_dis > 1:  # put all losses of the discriminators in one plot
            for i in range(n_dis):
                d_wrtr[i].add_summary(
                    sess.run(summary_discr[i], dict(zip(fetches, results))),
                    step)
                d_wrtr[i].flush()

    steps_per_epoch = len(x_train) // batch_size - 1
    for _ in tqdm(range(epochs)):
        for _ in range(steps_per_epoch):
            sess.run(model['d_step'])
            sess.run(model['g_step'])
        # tensorboard writer
        summ(sess.run(model["step"]) // steps_per_epoch)
        if n_dis > 1:
            summ_discr(sess.run(model["step"]) // steps_per_epoch)

    saver.save(sess, pform(path_ckpt, trial), write_meta_graph=False)

    # release the event files; the interactive session is left open on purpose
    wrtr.close()
    if n_dis > 1:
        for writer in d_wrtr.values():
            writer.close()