Example 1
def main(dataname):

    model = CAE()
    name = "CAE"

    model_dict = torch.load('models/CAE_model', map_location='cpu')
    model.load_state_dict(model_dict)
    model.eval()

    EPOCH = 1
    BATCH_SIZE_TRAIN = 49950

    # dataname = "demonstrations/demo_00_02.pkl"
    save_model_path = "models/" + name + "_model"
    best_model_path = "models/" + name + "_best_model"

    train_data = MotionData(dataname)
    train_set = DataLoader(dataset=train_data,
                           batch_size=BATCH_SIZE_TRAIN,
                           shuffle=True)

    for epoch in range(EPOCH):
        epoch_loss = 0.0
        for batch, x in enumerate(train_set):
            # loss = model(x)
            z = model.encoder(x).tolist()
            print(len(z))
        # print(epoch, loss.item())
    data = np.asarray(z)
    return data
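
# Usage sketch (hypothetical call, not from the original source): MotionData is
# assumed to accept a pickled demonstration file like the commented-out path above.
latent = main("demonstrations/demo_00_02.pkl")
print(latent.shape)  # (num_samples_in_last_batch, latent_dim), by assumption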
Example 2
class Model(object):
    def __init__(self):
        self.model = CAE()
        model_dict = torch.load('models/CAE_model', map_location='cpu')
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def decoder(self, img, s, z):
        img = img / 128.0 - 1.0  # rescale raw pixel values to roughly [-1, 1]
        img = np.transpose(img, (2, 0, 1))
        img = torch.FloatTensor([img])
        s = torch.FloatTensor([s])
        z = torch.FloatTensor([z])
        context = (img, s, z)
        a_tensor = self.model.decoder(context)
        a_numpy = a_tensor.detach().numpy()[0]
        return list(a_numpy)
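
# Usage sketch with dummy inputs; the image size and the s/z dimensions are
# assumptions, since they depend on the CAE definition.
import numpy as np
m = Model()
img = np.zeros((64, 64, 3))  # raw image; rescaled to [-1, 1] inside decoder
action = m.decoder(img, s=[0.0] * 6, z=[0.0] * 2)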
Example 3
def run_forrest_run(dataset_list, activation_list, modelname):
    for dataset_name in dataset_list:
        for name in activation_list:
            for model in modelname:
                if model == "DNN":
                    dataset = Datasets()
                    if dataset_name == 'MNIST':
                        x_train, x_test, y_train, y_test = dataset.get_mnist(
                            "DNN")
                        num_classes = dataset.num_classes
                        input_shape = dataset.input_shape
                    elif dataset_name == 'Fashion-MNIST':
                        x_train, x_test, y_train, y_test = dataset.get_fashion_mnist(
                            "DNN")
                        num_classes = dataset.num_classes
                        input_shape = dataset.input_shape
                    dnn = DNN(name)
                    score, history = dnn.run_model(input_shape, x_train,
                                                   x_test, y_train, y_test, 1)

                else:
                    dataset = Datasets()
                    if dataset_name == 'MNIST':
                        x_train, x_test, y_train, y_test = dataset.get_mnist(
                            "CNN")
                    elif dataset_name == 'Fashion-MNIST':
                        x_train, x_test, y_train, y_test = dataset.get_fashion_mnist(
                            "CNN")
                    num_classes = dataset.num_classes
                    input_shape = dataset.input_shape
                    if model == "CNN":
                        cnn = CNN(name)
                        score, history = cnn.run_model(input_shape, x_train,
                                                       x_test, y_train, y_test)
                    elif model == "CAE":
                        cae = CAE(name)
                        score, history = cae.run_model(input_shape, x_train,
                                                       x_test, y_train, y_test)
                plot_model(history, name, model, dataset_name)
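
# Usage sketch: run every dataset/activation/model combination; the activation
# names are assumptions about what DNN and CNN accept.
run_forrest_run(dataset_list=['MNIST', 'Fashion-MNIST'],
                activation_list=['relu'],
                modelname=['DNN', 'CNN'])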
Example 4
def get_model(name, device):
    """
    Returns required classifier and autoencoder
    :param name:
    :return: Autoencoder, Classifier
    """
    if name == 'lenet':
        model = LeNet(in_channels=channels).to(device)
    elif name == 'alexnet':
        model = AlexNet(channels=channels, num_classes=10).to(device)
    elif name == 'vgg':
        model = VGG(in_channels=channels, num_classes=10).to(device)
    else:
        raise NotImplementedError('Unknown classifier name')

    autoencoder = CAE(in_channels=channels).to(device)
    return model, autoencoder
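
# Usage sketch: `channels` is read from the enclosing scope, so it must be set
# before calling get_model; 1 (grayscale input) is an assumption for illustration.
channels = 1
model, autoencoder = get_model('lenet', torch.device('cpu'))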
Example 5
class Model(object):

    def __init__(self):
        self.model = CAE()
        model_dict = torch.load('models/CAE_best_model', map_location='cpu')
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def decoder(self, z, s):
        if abs(z[0][0]) < 0.01:
            return [0.0] * 6
        # z = np.asarray([z])
        z = np.asarray(z)
        # print(s.shape)
        z_tensor = torch.FloatTensor(np.concatenate((z, s), axis=1))
        # print(z_tensor.shape)
        a_tensor = self.model.decoder(z_tensor)
        return a_tensor.tolist()

    def encoder(self, a, s):
        x = np.concatenate((a, s), axis=1)
        x_tensor = torch.FloatTensor(x)
        z = self.model.encoder(x_tensor)
        return z.tolist()
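
# Round-trip sketch with dummy data. The action dimension of 6 matches the
# early-return above; the state dimension is an assumption. All-zero latent
# codes would take the early-return branch, so dummy ones are used here.
import numpy as np
m = Model()
a = np.ones((1, 6))  # one action vector
s = np.ones((1, 6))  # one state vector (dimension assumed)
z = m.encoder(a, s)  # nested list of latent codes
a_hat = m.decoder(z, s)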
Example 6
def test_gpu_cpu_3():
    inputs = np.random.normal(0, 0.1, size=[1, 30, 300, 1])

    with tf.device("/cpu:0"):
        np.random.seed(1234)
        model_cpu = CAE(const.INPUT_SHAPE)

    with tf.device('/gpu:0'):
        np.random.seed(1234)
        model_gpu = CAE(const.INPUT_SHAPE)

    results_cpu = model_cpu.predict([inputs])
    results_gpu = model_gpu.predict([inputs])

    print(np.array_equal(results_cpu, results_gpu))
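    # Note: np.random.seed does not seed TensorFlow's weight initializers, and
    # exact float equality between CPU and GPU results is brittle; a
    # tolerance-based comparison is the usual alternative:
    print(np.allclose(results_cpu, results_gpu, atol=1e-6))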
Example 7
sys.path.append('./src')

from load_topos import load_topos
from models import CAE
from training import pretrain_cae
import numpy as np

import wandb
from wandb.keras import WandbCallback
from time import gmtime, strftime

date = strftime("%Y-%m-%d %H:%M:%S", gmtime())
wandb.init(project='deep-artefact',
           group='001_pretrain_cae',
           name=f'001_pretrain_cae_{date}')

x = load_topos('./data/dummy.mat')
cae = CAE(input_shape=x.shape[1:], filters=[33, 64, 128, 61])
history = pretrain_cae(x,
                       cae,
                       batch_size=256,
                       epochs=5,
                       optimizer='adam',
                       save_dir='./data/',
                       callbacks=[WandbCallback()])

loss = history.history['loss']
epochs = np.arange(1, len(loss) + 1)

np.save('./data/cae_loss.npy', np.array([loss, epochs]))
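
# Sketch for reading the array back: np.array([loss, epochs]) stacks the two
# sequences row-wise, so they unpack the same way.
loss, epochs = np.load('./data/cae_loss.npy')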
Example 8
def extract_features(path_boxes_np, CAE_model_path, args):
    f_imgs, g_imgs, b_imgs = util.CAE_dataset_feed_dict(
        prefix, path_boxes_np, args.dataset)
    print('dataset loaded!')
    iters = len(np.load(path_boxes_np))

    former_batch = tf.placeholder(dtype=tf.float32,
                                  shape=[1, 64, 64, 1],
                                  name='former_batch')
    gray_batch = tf.placeholder(dtype=tf.float32,
                                shape=[1, 64, 64, 1],
                                name='gray_batch')
    back_batch = tf.placeholder(dtype=tf.float32,
                                shape=[1, 64, 64, 1],
                                name='back_batch')

    grad1_x, grad1_y = tf.image.image_gradients(former_batch)
    # grad2_x,grad2_y=tf.image.image_gradients(gray_batch)
    grad3_x, grad3_y = tf.image.image_gradients(back_batch)

    grad_dis_1 = tf.sqrt(tf.square(grad1_x) + tf.square(grad1_y))
    grad_dis_2 = tf.sqrt(tf.square(grad3_x) + tf.square(grad3_y))

    former_feat = CAE.CAE_encoder(grad_dis_1,
                                  'former',
                                  bn=args.bn,
                                  training=False)
    gray_feat = CAE.CAE_encoder(gray_batch, 'gray', bn=args.bn, training=False)
    back_feat = CAE.CAE_encoder(grad_dis_2, 'back', bn=args.bn, training=False)
    # [batch_size,3072]
    feat = tf.concat([
        tf.layers.flatten(former_feat),
        tf.layers.flatten(gray_feat),
        tf.layers.flatten(back_feat)
    ],
                     axis=1)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='former_encoder')
    var_list.extend(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                          scope='gray_encoder'))
    var_list.extend(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                          scope='back_encoder'))

    g_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                               scope='former_encoder')
    g_list.extend(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='gray_encoder'))
    g_list.extend(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='back_encoder'))
    bn_list = [
        g for g in g_list
        if 'moving_mean' in g.name or 'moving_variance' in g.name
    ]
    var_list += bn_list

    restorer = tf.train.Saver(var_list=var_list)
    data = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if args.bn:
            restorer.restore(sess, CAE_model_path + '_bn')
        else:
            restorer.restore(sess, CAE_model_path)
        for i in range(iters):
            feed_dict = {
                former_batch: np.expand_dims(f_imgs[i], 0),
                gray_batch: np.expand_dims(g_imgs[i], 0),
                back_batch: np.expand_dims(b_imgs[i], 0)
            }
            result = sess.run(feat, feed_dict=feed_dict)

            if args.norm == 0:
                _temp = result[0]
            else:
                _temp = util.norm_(result[0], l=args.norm)[0]

            if args.class_add:
                c_onehot_embedding = np.zeros(90, dtype=np.float32)
                c_onehot_embedding[class_indexes[i] - 1] = 1
                _temp = np.concatenate((_temp, c_onehot_embedding), axis=0)

            data.append(_temp)
        data = np.array(data)

    return data
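
# Invocation sketch: `prefix` (and `class_indexes`, when args.class_add is set)
# are module-level globals in the original script; the Namespace fields below
# are assumptions for illustration.
import argparse
args = argparse.Namespace(dataset='avenue', bn=False, norm=0, class_add=False)
features = extract_features('boxes.npy', 'checkpoints/CAE', args)  # (num_boxes, 3072)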
Example 9
def train_CAE(path_boxes_np, args):
    epoch_len = len(np.load(path_boxes_np))
    f_imgs, g_imgs, b_imgs = util.CAE_dataset_feed_dict(
        prefix, path_boxes_np, dataset_name=args.dataset)
    #former_batch,gray_batch,back_batch=util.CAE_dataset(path_boxes_np,args.dataset,epochs,batch_size)
    former_batch = tf.placeholder(dtype=tf.float32,
                                  shape=[batch_size, 64, 64, 1],
                                  name='former_batch')
    gray_batch = tf.placeholder(dtype=tf.float32,
                                shape=[batch_size, 64, 64, 1],
                                name='gray_batch')
    back_batch = tf.placeholder(dtype=tf.float32,
                                shape=[batch_size, 64, 64, 1],
                                name='back_batch')

    grad1_x, grad1_y = tf.image.image_gradients(former_batch)
    # grad2_x,grad2_y=tf.image.image_gradients(gray_batch)
    grad3_x, grad3_y = tf.image.image_gradients(back_batch)

    grad_dis_1 = tf.sqrt(tf.square(grad1_x) + tf.square(grad1_y))
    grad_dis_2 = tf.sqrt(tf.square(grad3_x) + tf.square(grad3_y))

    former_outputs = CAE.CAE(grad_dis_1, 'former', bn=args.bn, training=True)
    gray_outputs = CAE.CAE(gray_batch, 'gray', bn=args.bn, training=True)
    back_outputs = CAE.CAE(grad_dis_2, 'back', bn=args.bn, training=True)

    former_loss = CAE.pixel_wise_L2_loss(former_outputs, grad_dis_1)
    gray_loss = CAE.pixel_wise_L2_loss(gray_outputs, gray_batch)
    back_loss = CAE.pixel_wise_L2_loss(back_outputs, grad_dis_2)

    global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    global_step_a = tf.Variable(0, dtype=tf.int32, trainable=False)
    global_step_b = tf.Variable(0, dtype=tf.int32, trainable=False)

    lr_decay_epochs[0] = int(epoch_len // batch_size) * lr_decay_epochs[0]

    lr = tf.train.piecewise_constant(global_step,
                                     boundaries=lr_decay_epochs,
                                     values=learning_rate)

    former_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope='former_')
    gray_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='gray_')
    back_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='back_')
    # print(former_vars)
    if args.weight_reg != 0:
        former_loss = former_loss + args.weight_reg * weiht_regualized_loss(
            former_vars)
        gray_loss = gray_loss + args.weight_reg * weiht_regualized_loss(
            gray_vars)
        back_loss = back_loss + args.weight_reg * weiht_regualized_loss(
            back_vars)

    former_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        former_loss, var_list=former_vars, global_step=global_step)
    gray_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        gray_loss, var_list=gray_vars, global_step=global_step_a)
    back_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        back_loss, var_list=back_vars, global_step=global_step_b)

    step = 0
    if not args.bn:
        writer = tf.summary.FileWriter(logdir=summary_save_path_pre +
                                       args.dataset)
    else:
        writer = tf.summary.FileWriter(logdir=summary_save_path_pre +
                                       args.dataset + '_bn')

    tf.summary.scalar('loss/former_loss', former_loss)
    tf.summary.scalar('loss/gray_loss', gray_loss)
    tf.summary.scalar('loss/back_loss', back_loss)
    tf.summary.image('inputs/former', grad_dis_1)
    tf.summary.image('inputs/gray', gray_batch)
    tf.summary.image('inputs/back', grad_dis_2)
    tf.summary.image('outputs/former', former_outputs)
    tf.summary.image('outputs/gray', gray_outputs)
    tf.summary.image('outputs/back', back_outputs)
    summary_op = tf.summary.merge_all()

    saver = tf.train.Saver(var_list=tf.global_variables())
    indices = list(range(epoch_len))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(epochs):
            random.shuffle(indices)
            for i in range(epoch_len // batch_size):
                feed_dict = {
                    former_batch: [
                        f_imgs[d]
                        for d in indices[i * batch_size:(i + 1) * batch_size]
                    ],
                    gray_batch: [
                        g_imgs[d]
                        for d in indices[i * batch_size:(i + 1) * batch_size]
                    ],
                    back_batch: [
                        b_imgs[d]
                        for d in indices[i * batch_size:(i + 1) * batch_size]
                    ]
                }
                step, _lr, _, _, _, _former_loss, _gray_loss, _back_loss = sess.run(
                    [
                        global_step, lr, former_op, gray_op, back_op,
                        former_loss, gray_loss, back_loss
                    ],
                    feed_dict=feed_dict)
                if step % 10 == 0:
                    print('At step {}'.format(step))
                    print('\tLearning Rate {:.4f}'.format(_lr))
                    print('\tFormer Loss {:.4f}'.format(_former_loss))
                    print('\tGray Loss {:.4f}'.format(_gray_loss))
                    print('\tBack Loss {:.4f}'.format(_back_loss))

                if step % 50 == 0:
                    _summary = sess.run(summary_op, feed_dict=feed_dict)
                    writer.add_summary(_summary, global_step=step)
        if not args.bn:
            saver.save(sess, model_save_path_pre + args.dataset)
        else:
            saver.save(sess, model_save_path_pre + args.dataset + '_bn')

        print('train finished!')
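
# Invocation sketch: batch_size, epochs, learning_rate, lr_decay_epochs, prefix
# and the save-path prefixes are module-level globals here; the Namespace fields
# below are assumptions for illustration.
import argparse
train_CAE('boxes.npy', argparse.Namespace(dataset='avenue', bn=False, weight_reg=0))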
Example 10
def train(opts):

    device = torch.device("cuda" if use_cuda else "cpu")  # use_cuda is expected to be set at module level

    if opts.arch == 'small':
        channels = [32, 32, 32, 10]
    elif opts.arch == 'large':
        channels = [256, 128, 64, 32]
    else:
        raise NotImplementedError('Unknown model architecture')

    if opts.mode == 'train_mnist':
        train_loader, valid_loader = get_mnist_loaders(opts.data_dir,
                                                       opts.bsize,
                                                       opts.nworkers,
                                                       opts.sigma, opts.alpha)
        model = CAE(1, 10, 28, opts.n_prototypes, opts.decoder_arch, channels)
    elif opts.mode == 'train_cifar':
        train_loader, valid_loader = get_cifar_loaders(opts.data_dir,
                                                       opts.bsize,
                                                       opts.nworkers,
                                                       opts.sigma, opts.alpha)
        model = CAE(3, 10, 32, opts.n_prototypes, opts.decoder_arch, channels)
    elif opts.mode == 'train_fmnist':
        train_loader, valid_loader = get_fmnist_loaders(
            opts.data_dir, opts.bsize, opts.nworkers, opts.sigma, opts.alpha)
        model = CAE(1, 10, 28, opts.n_prototypes, opts.decoder_arch, channels)
    else:
        raise NotImplementedError('Unknown train mode')

    if opts.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opts.lr,
                                     weight_decay=opts.wd)
    else:
        raise NotImplementedError("Unknown optim type")
    criterion = nn.CrossEntropyLoss()

    start_n_iter = 0
    # for choosing the best model
    best_val_acc = 0.0

    model_path = os.path.join(opts.save_path, 'model_latest.net')
    if opts.resume and os.path.exists(model_path):
        # restoring training from save_state
        print('====> Resuming training from previous checkpoint')
        save_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(save_state['state_dict'])
        start_n_iter = save_state['n_iter']
        best_val_acc = save_state['best_val_acc']
        opts = save_state['opts']
        opts.start_epoch = save_state['epoch'] + 1

    model = model.to(device)

    # for logging
    logger = TensorboardLogger(opts.start_epoch, opts.log_iter, opts.log_dir)
    logger.set(['acc', 'loss', 'loss_class', 'loss_ae', 'loss_r1', 'loss_r2'])
    logger.n_iter = start_n_iter

    for epoch in range(opts.start_epoch, opts.epochs):
        model.train()
        logger.step()
        valid_sample = torch.stack([
            valid_loader.dataset[i][0]
            for i in random.sample(range(len(valid_loader.dataset)), 10)
        ]).to(device)

        for batch_idx, (data, target) in enumerate(train_loader):
            acc, loss, class_error, ae_error, error_1, error_2 = run_iter(
                opts, data, target, model, criterion, device)

            # optimizer step
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), opts.max_norm)
            optimizer.step()

            logger.update(acc, loss, class_error, ae_error, error_1, error_2)

        val_loss, val_acc, val_class_error, val_ae_error, val_error_1, val_error_2, time_taken = evaluate(
            opts, model, valid_loader, criterion, device)
        # log the validation losses
        logger.log_valid(time_taken, val_acc, val_loss, val_class_error,
                         val_ae_error, val_error_1, val_error_2)
        print('')

        # Save the model to disk
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            save_state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'n_iter': logger.n_iter,
                'opts': opts,
                'val_acc': val_acc,
                'best_val_acc': best_val_acc
            }
            model_path = os.path.join(opts.save_path, 'model_best.net')
            torch.save(save_state, model_path)
            prototypes = model.save_prototypes(opts.save_path,
                                               'prototypes_best.png')
            x = torchvision.utils.make_grid(prototypes, nrow=10, pad_value=1.0)
            logger.writer.add_image('Prototypes (best)', x, epoch)

        save_state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'n_iter': logger.n_iter,
            'opts': opts,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc
        }
        model_path = os.path.join(opts.save_path, 'model_latest.net')
        torch.save(save_state, model_path)
        prototypes = model.save_prototypes(opts.save_path,
                                           'prototypes_latest.png')
        x = torchvision.utils.make_grid(prototypes, nrow=10, pad_value=1.0)
        logger.writer.add_image('Prototypes (latest)', x, epoch)
        ae_samples = model.get_decoded_pairs_grid(valid_sample)
        logger.writer.add_image('AE_samples_latest', ae_samples, epoch)
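
# Restore sketch (hypothetical path): the checkpoint written above can be
# reloaded through the same save_state keys; the CAE arguments mirror the
# 'train_mnist' / 'small' branch.
save_state = torch.load('model_best.net', map_location='cpu')
restored = CAE(1, 10, 28, save_state['opts'].n_prototypes,
               save_state['opts'].decoder_arch, [32, 32, 32, 10])
restored.load_state_dict(save_state['state_dict'])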
Example 11

if __name__ == "__main__":
    data_dir = 'data/coco2017'
    log_dir = 'logs'
    input_image_size = (256, 256, 3)
    batch_size = 10
    latent_dim = 32
    optimizer = tf.keras.optimizers.Adam(1e-3)

    # Initialize and compile models
    incept_model = InceptionV3(include_top=False,
                               pooling='avg',
                               input_shape=input_image_size)
    ae_model = AE(latent_dim, input_image_size)
    cae_model = CAE(latent_dim, input_image_size)
    vae_model = VAE(latent_dim, input_image_size)
    cvae_model = CVAE(latent_dim, input_image_size)
    memcae_model = MemCAE(latent_dim, True, input_image_size, batch_size, 500,
                          optimizer)

    for classes in [['cat']]:
        # Load and augment training data
        ds_train = dataloader(classes, data_dir, input_image_size, batch_size,
                              'train2019')
        ds_val = dataloader(classes, data_dir, input_image_size, batch_size,
                            'val2019')
        class_label = classes[0] if len(classes) == 1 else "similar"

        # Train each model for comparison
        for m in [memcae_model]:
Example 12
    def __init__(self):
        self.model = CAE()
        model_dict = torch.load('models/CAE_model', map_location='cpu')
        self.model.load_state_dict(model_dict)
        self.model.eval()
Example 13
def train_CAE(path_boxes_np, args):
    epoch_len = len(np.load(path_boxes_np))
    f_imgs, g_imgs, b_imgs, class_indexs = util.CAE_dataset_feed_dict(
        prefix, path_boxes_np, dataset_name=args.dataset)
    #former_batch,gray_batch,back_batch=util.CAE_dataset(path_boxes_np,args.dataset,epochs,batch_size)
    former_batch = tf.placeholder(dtype=tf.float32,
                                  shape=[batch_size, 64, 64, 1],
                                  name='former_batch')
    gray_batch = tf.placeholder(dtype=tf.float32,
                                shape=[batch_size, 64, 64, 1],
                                name='gray_batch')
    back_batch = tf.placeholder(dtype=tf.float32,
                                shape=[batch_size, 64, 64, 1],
                                name='back_batch')

    # * tf.image.image_gradients() computes per-image gradients along x and y,
    #   which does not match the paper's intent.
    # * It should instead compute the frame difference (absdiff) between
    #   frame_{t} and frame_{t-3}, and between frame_{t} and frame_{t+3}.

    # grad1_x, grad1_y = tf.image.image_gradients(former_batch)
    # grad1=tf.concat([grad1_x,grad1_y],axis=-1)
    grad1 = tf.math.abs(tf.math.subtract(former_batch, gray_batch))
    # grad2_x,grad2_y=tf.image.image_gradients(gray_batch)
    # grad3_x, grad3_y = tf.image.image_gradients(back_batch)
    # grad3=tf.concat([grad3_x,grad3_y],axis=-1)
    grad3 = tf.math.abs(tf.math.subtract(back_batch, gray_batch))

    #grad_dis_1 = tf.sqrt(tf.square(grad1_x) + tf.square(grad1_y))
    #grad_dis_2 = tf.sqrt(tf.square(grad3_x) + tf.square(grad3_y))

    former_outputs = CAE.CAE(grad1, 'former', bn=args.bn, training=True)
    gray_outputs = CAE.CAE(gray_batch, 'gray', bn=args.bn, training=True)
    back_outputs = CAE.CAE(grad3, 'back', bn=args.bn, training=True)

    former_loss = CAE.pixel_wise_L2_loss(former_outputs, grad1)
    gray_loss = CAE.pixel_wise_L2_loss(gray_outputs, gray_batch)
    back_loss = CAE.pixel_wise_L2_loss(back_outputs, grad3)

    global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    global_step_a = tf.Variable(0, dtype=tf.int32, trainable=False)
    global_step_b = tf.Variable(0, dtype=tf.int32, trainable=False)

    lr_decay_epochs[0] = int(epoch_len // batch_size) * lr_decay_epochs[0]

    lr = tf.train.piecewise_constant(global_step,
                                     boundaries=lr_decay_epochs,
                                     values=learning_rate)

    former_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope='former_')
    gray_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='gray_')
    back_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='back_')
    # print(former_vars)
    if args.weight_reg != 0:
        former_loss = former_loss + args.weight_reg * weiht_regualized_loss(
            former_vars)
        gray_loss = gray_loss + args.weight_reg * weiht_regualized_loss(
            gray_vars)
        back_loss = back_loss + args.weight_reg * weiht_regualized_loss(
            back_vars)

    former_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        former_loss, var_list=former_vars, global_step=global_step)
    gray_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        gray_loss, var_list=gray_vars, global_step=global_step_a)
    back_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
        back_loss, var_list=back_vars, global_step=global_step_b)

    step = 0
    if not args.bn:
        logdir = f'{summary_save_path_pre}/{args.dataset}'
    else:
        logdir = f'{summary_save_path_pre}/{args.dataset}_bn'
    writer = tf.summary.FileWriter(logdir=logdir)

    tf.summary.scalar('loss/former_loss', former_loss)
    tf.summary.scalar('loss/gray_loss', gray_loss)
    tf.summary.scalar('loss/back_loss', back_loss)
    #tf.summary.image('inputs/former',grad_dis_1)
    tf.summary.image('inputs/gray', gray_batch)
    #tf.summary.image('inputs/back',grad_dis_2)
    #tf.summary.image('outputs/former',former_outputs)
    tf.summary.image('outputs/gray', gray_outputs)
    #tf.summary.image('outputs/back',back_outputs)
    summary_op = tf.summary.merge_all()

    saver = tf.train.Saver(var_list=tf.global_variables())
    indices = list(range(epoch_len))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(epochs):
            random.shuffle(indices)
            for i in range(epoch_len // batch_size):
                feed_dict = {
                    former_batch: [
                        f_imgs[d]
                        for d in indices[i * batch_size:(i + 1) * batch_size]
                    ],
                    gray_batch: [
                        g_imgs[d]
                        for d in indices[i * batch_size:(i + 1) * batch_size]
                    ],
                    back_batch: [
                        b_imgs[d]
                        for d in indices[i * batch_size:(i + 1) * batch_size]
                    ]
                }
                step, _lr, _, _, _, _former_loss, _gray_loss, _back_loss = sess.run(
                    [
                        global_step, lr, former_op, gray_op, back_op,
                        former_loss, gray_loss, back_loss
                    ],
                    feed_dict=feed_dict)
                step_result = f'step{step}: lr={_lr:.4f}, fl={_former_loss:.4f}, gl={_gray_loss:.4f}, bl={_back_loss:.4f}'
                if step % 10 == 0:
                    print(step_result)

                if step % 50 == 0:
                    _summary = sess.run(summary_op, feed_dict=feed_dict)
                    writer.add_summary(_summary, global_step=step)
        if not args.bn:
            ckpt_path = f'{model_save_path_pre}{args.dataset}/{args.dataset}.ckpt'
        else:
            ckpt_path = f'{model_save_path_pre}{args.dataset}_bn/{args.dataset}.ckpt'
        saver.save(sess, ckpt_path)

        print('train finished!')
Example 14
                        help='number of total epochs to run')
    args = parser.parse_args()
    return args


args = parse_opts()

print(args)

os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_id

# create tensorboard writer
cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))

if args.base_model == 'cae_4':
    model = CAE.CAE_4(data_len=1000, kernel_size=8, is_skip=args.is_skip)
elif args.base_model == 'cae_5':
    model = CAE.CAE_5(data_len=1000, kernel_size=8, is_skip=args.is_skip)
elif args.base_model == 'cae_6':
    model = CAE.CAE_6(data_len=1000, kernel_size=8, is_skip=args.is_skip)
elif args.base_model == 'cae_7':
    model = CAE.CAE_7(data_len=1000, kernel_size=8, is_skip=args.is_skip)
elif args.base_model == 'cae_8':
    model = CAE.CAE_8(data_len=1000, kernel_size=8, is_skip=args.is_skip)
elif args.base_model == 'cae_9':
    model = CAE.CAE_9(data_len=1000, kernel_size=8, is_skip=args.is_skip)
else:
    raise NotImplementedError('Unknown base model: {}'.format(args.base_model))
model.cuda()
optimizer = torch.optim.SGD(model.parameters(),
                            lr=args.lr,
                            momentum=0.9,
                            weight_decay=args.weight_decay)
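
# Minimal training-loop sketch: `train_loader` is assumed to yield batches of
# length-1000 signals matching data_len above, and MSE reconstruction loss is
# an assumption.
criterion = torch.nn.MSELoss()
for x in train_loader:
    x = x.cuda()
    loss = criterion(model(x), x)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()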