Ejemplo n.º 1
0
def main():
    """Classify a single image with pretrained VGG16 and report the top-5 labels.

    Reads the image/model/synset paths from module-level constants, optionally
    runs on a GPU selected via --gpu, and saves a visualization next to the
    input image as 'result.jpg'.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument('--gpu', '-g', type=int, default=-1, help='gpu id')
    opts = cli.parse_args()

    # load model
    model = VGG16()
    print('Loading pretrained model from {0}'.format(MODEL_PATH))
    chainer.serializers.load_hdf5(MODEL_PATH, model)

    on_gpu = opts.gpu >= 0
    if on_gpu:
        chainer.cuda.get_device_from_id(opts.gpu).use()
        model.to_gpu()

    # inference-only configuration: no dropout/BN updates, no graph building
    chainer.config.train = False
    chainer.config.enable_backprop = False

    # prepare net input
    print('Loading image from {0}'.format(IMAGE_PATH))
    image = scipy.misc.imread(IMAGE_PATH, mode='RGB')
    image = scipy.misc.imresize(image, (224, 224))
    img_in = image.copy()

    # RGB -> BGR, then subtract the per-channel mean (Caffe VGG convention)
    bgr = image[:, :, ::-1].astype(np.float32)
    bgr -= np.array([104, 117, 123], dtype=np.float32)

    x_data = np.array([bgr.transpose(2, 0, 1)])  # HWC -> NCHW, batch of 1
    if on_gpu:
        x_data = chainer.cuda.to_gpu(x_data)
    x = chainer.Variable(x_data)

    # infer
    model(x)
    score = chainer.cuda.to_cpu(model.score.data[0])

    # visualize result: softmax over the raw class scores
    likelihood = np.exp(score) / np.sum(np.exp(score))
    order = np.argsort(score)

    print('Loading label_names from {0}'.format(SYNSET_PATH))
    with open(SYNSET_PATH, 'r') as f:
        label_names = np.array([line.strip() for line in f])

    print('Likelihood of top5:')
    top5 = order[::-1][:5]
    for index in top5:
        print('  {0:5.1f}%: {1}'
              .format(likelihood[index] * 100, label_names[index]))

    img_viz = draw_image_classification_top5(
        img_in, label_names[top5], likelihood[top5])
    out_file = osp.join(osp.dirname(IMAGE_PATH), 'result.jpg')
    plt.imsave(out_file, img_viz)
    print('Saved as: {0}'.format(out_file))
Ejemplo n.º 2
0
def main():
    """Run flux prediction on the test split and dump per-image visualizations
    and .mat files into args.test_vis_dir/args.dataset/."""
    out_dir = args.test_vis_dir + args.dataset
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    model = VGG16()
    model.load_state_dict(
        torch.load(args.snapshot_dir + args.dataset + '_400000.pth'))
    model.eval()
    model.cuda()

    loader = DataLoader(
        FluxSegmentationDataset(dataset=args.dataset, mode='test'),
        batch_size=1, shuffle=False, num_workers=4)

    for step, batch in enumerate(loader):
        (Input_image, vis_image, gt_mask, gt_flux, weight_matrix,
         dataset_lendth, image_name) = batch

        print(step, dataset_lendth)

        pred_flux = model(Input_image.cuda())

        # side-by-side visualization of predicted vs ground-truth flux
        vis_flux(vis_image, pred_flux, gt_flux, gt_mask, image_name,
                 args.test_vis_dir + args.dataset + '/')

        # strip the batch dimension before saving
        pred_flux = pred_flux.data.cpu().numpy()[0, ...]
        sio.savemat(
            args.test_vis_dir + args.dataset + '/' + image_name[0] + '.mat',
            {'flux': pred_flux})
Ejemplo n.º 3
0
def eval_it():
    """Evaluate a saliency VGG16 model on a test-pair list and print metrics.

    CLI flags select the pair-list file (--test_file) and the weights file
    (--model_weights). The pair list is generated from the DUTS-TE folder if
    it does not exist yet. Evaluation runs on GPU 0 via a data generator and
    reports acc / pre / rec / F-value / MAE.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    import argparse
    parser = argparse.ArgumentParser(description='Train model your dataset')
    parser.add_argument('--test_file',
                        default='test_pair.txt',
                        help='your train file',
                        type=str)
    parser.add_argument('--model_weights',
                        default='result_PFA_FLoss/PFA_00500.h5',
                        help='your model weights',
                        type=str)
    batch_size = 32

    args = parser.parse_args()
    model_name = args.model_weights
    test_path = args.test_file
    HOME = os.path.expanduser('~')
    test_folder = os.path.join(
        HOME,
        '../ads-creative-image-algorithm/public_data/datasets/SalientDataset/DUTS/DUTS-TE'
    )
    # Build the pair list on the fly if it is missing.
    if not os.path.exists(test_path):
        ge_train_pair(test_path, test_folder, "DUTS-TE-Image", "DUTS-TE-Mask")
    target_size = (256, 256)
    with open(test_path, 'r') as f:
        testlist = f.readlines()
    # BUGFIX: use floor division -- Keras expects an integral `steps` count,
    # and `/` yields a float under Python 3.
    steps_per_epoch = len(testlist) // batch_size
    optimizer = optimizers.SGD(lr=1e-2, momentum=0.9, decay=0)
    loss = EdgeHoldLoss
    metrics = [acc, pre, rec, F_value, MAE]
    dropout = False
    with_CPFE = True
    with_CA = True
    with_SA = True

    # VGG16 downsamples by 32, so both dimensions must be divisible by 32.
    if target_size[0] % 32 != 0 or target_size[1] % 32 != 0:
        raise ValueError('Image height and weight must be a multiple of 32')
    testgen = getTestGenerator(test_path, target_size, batch_size)
    model_input = Input(shape=(target_size[0], target_size[1], 3))
    model = VGG16(model_input,
                  dropout=dropout,
                  with_CPFE=with_CPFE,
                  with_CA=with_CA,
                  with_SA=with_SA)
    model.load_weights(model_name, by_name=True)

    # Evaluation only: freeze every layer before compiling.
    for layer in model.layers:
        layer.trainable = False
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    # `- 1` drops the final (possibly partial) batch.
    evalSal = model.evaluate_generator(testgen, steps_per_epoch - 1, verbose=1)
    print(evalSal)
Ejemplo n.º 4
0
def main():
    """Run the CPU flux model over frames in '..\\video frame' and save
    visualizations plus per-frame .mat flux files."""
    if not os.path.exists(args.test_vis_dir + args.dataset):
        os.makedirs(args.test_vis_dir + args.dataset)

    model = VGG16()
    model.load_state_dict(
        torch.load('PascalContext_400000.pth',
                   map_location=torch.device('cpu')))
    model.eval()

    # BUGFIX: raw string -- '\*' is an invalid escape sequence that only
    # worked by accident (DeprecationWarning on modern Python); the raw
    # string has the identical value.
    image_dir = r'..\video frame\*'
    image_files = sorted(glob.glob(image_dir))
    # Per-channel means in BGR order, matching cv2.imread's channel layout.
    IMAGE_MEAN = np.array([103.939, 116.779, 123.675], dtype=np.float32)
    for image_path in image_files:
        # Windows-style path: '<dir>\<name>.<ext>' -> '<name>'
        image_name = image_path.split('\\')[-1].split('.')[0]
        print(image_path, image_name)
        image = cv2.imread(image_path, 1)
        vis_image = image.copy()

        image = image.astype(np.float32)
        image -= IMAGE_MEAN
        image = image.transpose(2, 0, 1)  # HWC -> CHW

        Input_image = torch.from_numpy(image).unsqueeze(0)
        # BUGFIX: torch.no_grad().__enter__ returns None, so the previous
        # 'as f' binding was meaningless; drop it.
        with torch.no_grad():
            pred_flux = model(Input_image)

        vis_flux_v2(vis_image, pred_flux, image_name, args.test_vis_dir)

        # strip the batch dimension before saving
        pred_flux = pred_flux.numpy()[0, ...]
        sio.savemat(
            args.test_vis_dir + args.dataset + '/' + image_name + '.mat',
            {'flux': pred_flux})
Ejemplo n.º 5
0
def main():
    """Classify IMAGE_PATH with pretrained VGG16 and display the top-5 result
    on screen (CPU-only, Chainer v1-style volatile inference)."""
    # load model
    model = VGG16()
    print('Loading pretrained model from {0}'.format(MODEL_PATH))
    chainer.serializers.load_hdf5(MODEL_PATH, model)

    # prepare net input
    print('Loading image from {0}'.format(IMAGE_PATH))
    image = scipy.misc.imread(IMAGE_PATH, mode='RGB')
    image = scipy.misc.imresize(image, (224, 224))
    img_in = image.copy()

    # RGB -> BGR, then subtract the per-channel mean (Caffe VGG convention)
    bgr = image[:, :, ::-1].astype(np.float32)
    bgr -= np.array([104, 117, 123], dtype=np.float32)

    x_data = np.array([bgr.transpose(2, 0, 1)])  # HWC -> NCHW, batch of 1
    # volatile='ON' is the Chainer v1 way of disabling graph construction.
    x = chainer.Variable(x_data, volatile='ON')

    # infer
    model(x)
    score = model.score.data[0]

    # visualize result: softmax over the raw class scores
    likelihood = np.exp(score) / np.sum(np.exp(score))
    order = np.argsort(score)

    print('Loading label_names from {0}'.format(SYNSET_PATH))
    with open(SYNSET_PATH, 'r') as f:
        label_names = np.array([line.strip() for line in f])

    print('Likelihood of top5:')
    top5 = order[::-1][:5]
    for index in top5:
        print('  {0:5.1f}%: {1}'
              .format(likelihood[index]*100, label_names[index]))

    img_viz = draw_image_classification_top5(
        img_in, label_names[top5], likelihood[top5])
    plt.imshow(img_viz)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
Ejemplo n.º 6
0
    def build_models(self):
        """Construct the networks used for training.

        Loads the pretrained RNN text encoder named by ``cfg.TRAIN.NET_E``
        (frozen, eval mode), builds a fresh generator/discriminator pair and
        a VGG16 feature extractor, and moves everything to the GPU when
        ``cfg.CUDA`` is set.

        Returns:
            list: ``[text_encoder, netG, netD, epoch, VGG]`` where ``epoch``
            is always 0 here (checkpoint-resume code is disabled below), or
            ``None`` early if no encoder path is configured.
        """
        ################### Text and Image encoders ########################################
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return
        # NOTE(review): dead code -- the triple-quoted block below is disabled
        # image-encoder loading kept verbatim; consider deleting it.
        """
        image_encoder = CNN_dummy()
        image_encoder.cuda()
        image_encoder.eval()
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()
        """

        # VGG16 is used as a frozen feature extractor (eval mode only).
        VGG = VGG16()
        VGG.eval()

        # Text encoder: restore pretrained weights, then freeze everything.
        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        ####################### Generator and Discriminators ##############
        from model import D_NET256
        netD = D_NET256()
        netG = EncDecNet()

        # Fresh G/D start from the custom weight initialization.
        netD.apply(weights_init)
        netG.apply(weights_init)

        # Training always starts at epoch 0 (resume logic is disabled below).
        epoch = 0
        # NOTE(review): dead code -- disabled checkpoint-resume logic; it
        # references `netsD`, which does not exist in this version.
        """
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)
        """
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            netG.cuda()
            netD.cuda()
            VGG.cuda()
        return [text_encoder, netG, netD, epoch, VGG]
Ejemplo n.º 7
0
    def calc_mp(self):
        """Compute and print the manipulative-precision (MP) metric.

        MP = (1 - diff) * sim, where ``diff`` is the mean L1 distance between
        the generated image and the real image, and ``sim`` is the mean cosine
        similarity between the (mismatched) caption embedding and the CNN
        embedding of the generated image. Results are printed, not returned.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for main module is not found!')
        else:
            #if split_dir == 'test':
            #    split_dir = 'valid'

            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = EncDecNet()
            # Random init is overwritten below once the checkpoint is loaded.
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            # The text encoder

            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()
            # The image encoder
            #image_encoder = CNN_dummy()
            #print('define image_encoder')

            # Image encoder weights live next to the text encoder checkpoint,
            # with 'text_encoder' swapped for 'image_encoder' in the filename.
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)

            image_encoder = image_encoder.cuda()
            image_encoder.eval()

            # The VGG network (frozen feature extractor for the generator)
            VGG = VGG16()
            print("Load the VGG model")
            #VGG.to(torch.device("cuda:1"))
            VGG.cuda()
            VGG.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            # NOTE(review): Variable(..., volatile=True) is the legacy
            # (pre-0.4) PyTorch inference API -- confirm the torch version.
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
            noise = noise.cuda()

            model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                     'Model', self.args.netG_epoch)
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save modified images

            cnt = 0
            idx = 0
            diffs, sims = [], []
            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data)

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################

                    hidden = text_encoder.init_hidden(batch_size)

                    # Deliberately uses the MISMATCHED captions: the metric
                    # measures how well G edits an image toward a new caption.
                    words_embs, sent_emb = text_encoder(
                        wrong_caps, wrong_caps_len, hidden)
                    words_embs, sent_emb = words_embs.detach(
                    ), sent_emb.detach()

                    # Mask padding tokens (id 0); clip the mask to the number
                    # of word embeddings actually produced.
                    mask = (wrong_caps == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    #######################################################
                    # (2) Modify real images
                    ######################################################

                    noise.data.normal_(0, 1)

                    # imgs[-1] is the highest-resolution real image.
                    fake_img, mu, logvar = netG(imgs[-1], sent_emb, words_embs,
                                                noise, mask, VGG)

                    # diff: how much the edit changed the source image.
                    diff = F.l1_loss(fake_img, imgs[-1])
                    diffs.append(diff.item())

                    region_code, cnn_code = image_encoder(fake_img)

                    # sim: agreement between target caption and edited image.
                    sim = cosine_similarity(sent_emb, cnn_code)
                    sim = torch.mean(sim)
                    sims.append(sim.item())

            diff = np.sum(diffs) / len(diffs)
            sim = np.sum(sims) / len(sims)
            print('diff: %.3f, sim:%.3f' % (diff, sim))
            print('MP: %.3f' % ((1 - diff) * sim))
            # Extract the epoch number from a filename like 'netG_epoch_600.pth'.
            netG_epoch = self.args.netG_epoch[self.args.netG_epoch.find('_') +
                                              1:-4]
            print('model_epoch:%s, diff: %.3f, sim:%.3f, MP:%.3f' %
                  (netG_epoch, np.sum(diffs) / len(diffs),
                   np.sum(sims) / len(sims), (1 - diff) * sim))
def pil_loader(path):
    """Load *path* with PIL and return it converted to RGB.

    Opening through an explicit file handle avoids a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as fh:
        with Image.open(fh) as image:
            return image.convert('RGB')


direc = image_dir + 'content_images/'
content_images = datasets.ImageFolder(direc, transform=load_content)

direc = image_dir + 'style_images/'
style_images = datasets.ImageFolder(direc, transform=load_style)

style_images = style_images[0][0].unsqueeze(0)
vgg = VGG16()
model = ImageTransformationNetwork(img_size)

if gpu:
    style_images_var = Variable(style_images.cuda(), requires_grad=False)
    vgg = vgg.cuda()
    model = model.cuda()
    style_loss_fns = [StyleReconstructionLoss().cuda()] * 4
    content_loss_fns = [FeatureReconstructionLoss().cuda()]
else:
    style_images_var = Variable(style_images, requires_grad=False)
    style_loss_fns = [StyleReconstructionLoss()] * 4
    content_loss_fns = [FeatureReconstructionLoss()]

for param in vgg.parameters():
    param.requires_grad = False
    model_save_period = 5
    if target_size[0] % 32 != 0 or target_size[1] % 32 != 0:
        raise ValueError('{} {}'.format('Image height and weight must',
                                        'be a multiple of 32'))

    # Training data generator and or shuffler
    traingen = getTrainGenerator(train_path,
                                 target_size,
                                 batch_size,
                                 israndom=False)

    # Model definition and options
    model_input = Input(shape=(target_size[0], target_size[1], 3))
    model = VGG16(model_input,
                  dropout=dropout,
                  with_CPFE=with_CPFE,
                  with_CA=with_CA,
                  with_SA=with_SA)
    model.load_weights(model_name, by_name=True)

    # Tensorflow & Tensorboard options
    tb = callbacks.TensorBoard(log_dir=tb_log)
    lr_decay = callbacks.LearningRateScheduler(schedule=lr_scheduler)
    es = callbacks.EarlyStopping(monitor='loss',
                                 patience=3,
                                 verbose=0,
                                 mode='auto')
    modelcheck = callbacks.ModelCheckpoint(model_save + '{epoch:05d}.h5',
                                           monitor='loss',
                                           verbose=1,
                                           save_best_only=False,
    cfg.data_path, 'train'),
                                              transform=transform_train)

test_data = torchvision.datasets.ImageFolder(os.path.join(
    cfg.data_path, 'test'),
                                             transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=cfg.bs,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=cfg.bs,
                                          shuffle=True)

net = VGG16.Net()

if cfg.resume:
    print('------------------------------')
    print('==> Loading the checkpoint ')
    if not os.path.exists(cfg.ckpt_path):
        raise AssertionError['Can not find path']
    checkpoint = torch.load(cfg.ckpt_path)
    net.load_state_dict(checkpoint['net'])
    best_test_acc = checkpoint['best_test_acc']
    print('best_test_acc is %.4f%%' % best_test_acc)
    best_test_acc_epoch = checkpoint['best_test_acc_epoch']
    print('best_test_acc_epoch is %d' % best_test_acc_epoch)
    start_epoch = checkpoint['best_test_acc_epoch'] + 1
else:
    print('------------------------------')
Ejemplo n.º 11
0
def test(FLAG):
    """Run segmentation inference on every '*_sat.jpg' under FLAG.test_dir.

    Restores a VGG16 checkpoint from FLAG.save_dir when available (otherwise
    uses fresh/initial weights), predicts a label map per image, upsamples it
    to 512x512 with nearest-neighbor resizing, colorizes it, and writes
    '<name>_mask.png' files into FLAG.plot_dir.
    """
    print("Reading dataset...")
    # load data
    # Strip the '_sat.jpg' suffix so read_list gets base file paths.
    file_list = [
        FLAG.test_dir + file.replace('_sat.jpg', '')
        for file in os.listdir(FLAG.test_dir) if file.endswith('_sat.jpg')
    ]
    file_list.sort()
    Xtest = read_list(file_list, with_mask=False)

    vgg16 = VGG16(classes=7, shape=(256, 256, 3))
    vgg16.build(vgg16_npy_path=FLAG.init_from, mode=FLAG.mode)

    # NOTE(review): this helper is defined but never called below.
    def initialize_uninitialized(sess):
        # Initialize only the variables that do not have values yet.
        global_vars = tf.global_variables()
        is_not_initialized = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [
            v for (v, f) in zip(global_vars, is_not_initialized) if not f
        ]
        if len(not_initialized_vars):
            sess.run(tf.variables_initializer(not_initialized_vars))

    with tf.Session() as sess:
        if FLAG.save_dir is not None:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(FLAG.save_dir)

            # Prefer checkpoint weights when a valid checkpoint exists.
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Model restored %s" % ckpt.model_checkpoint_path)
                sess.run(tf.global_variables())
            else:
                print("No model checkpoint in %s" % FLAG.save_dir)
        else:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.global_variables())
        print("Initialized")

        print("Plot saved in %s" % FLAG.plot_dir)
        for i, fname in enumerate(file_list):
            # Predict one image at a time (batch of 1, inference mode).
            Xplot = sess.run(
                vgg16.pred,
                feed_dict={
                    vgg16.x: Xtest[i:(i + 1), :],
                    #vgg16.y: Ytest[i:(i+1),:],
                    vgg16.is_train: False
                })
            # order=0 (nearest) + clip=False keep the label ids intact.
            saveimg = skimage.transform.resize(Xplot[0],
                                               output_shape=(512, 512),
                                               order=0,
                                               preserve_range=True,
                                               clip=False)
            saveimg = label2rgb(saveimg)
            imageio.imsave(
                os.path.join(FLAG.plot_dir,
                             os.path.basename(fname) + "_mask.png"), saveimg)
            print(
                os.path.join(FLAG.plot_dir,
                             os.path.basename(fname) + "_mask.png"))
def camera_predict():
    """Live webcam emotion classification.

    Continuously reads frames from camera 0 and draws a rectangle around a
    detected face. Pressing SPACE snapshots the face, runs it through the
    VGG16 classifier, and overlays the predicted emotion label on subsequent
    frames. Pressing 'q' quits.
    """

    video_captor = cv2.VideoCapture(0)
    predicted_class = None
    while True:
        ret, frame = video_captor.read()

        face_img, face_coor = camera_face_detect.face_d(frame)

        # Outline the detected face on the live preview.
        if face_coor is not None:
            [x_screen, y_screen, w_screen, h_screen] = face_coor
            cv2.rectangle(frame, (x_screen, y_screen),
                          (x_screen + w_screen, y_screen + h_screen),
                          (255, 0, 0), 2)

        # SPACE pressed: classify the current face snapshot.
        if cv2.waitKey(1) & 0xFF == ord(' '):
            face_img, face_coor = camera_face_detect.face_d(frame)

            if face_coor is not None:
                [x, y, w, h] = face_coor
                cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)

            if face_img is not None:
                if not os.path.exists(cfg.output_folder):
                    os.mkdir('./camera_output')
                cv2.imwrite(os.path.join(cfg.output_folder, 'face_image.jpg'),
                            face_img)
                # Grayscale, resize, then replicate to 3 channels so the
                # RGB-trained network accepts the input.
                gray = cv2.cvtColor(face_img, cv2.COLOR_RGB2GRAY)
                gray = cv2.resize(gray, (240, 240))
                img = gray[:, :, np.newaxis]
                img = np.concatenate((img, img, img), axis=2)
                img = Image.fromarray(np.uint8(img))
                inputs = transform_test(img)
                class_names = [
                    'Angry', 'Disgusted', 'Fearful', 'Happy', 'Sad',
                    'Surprised', 'Neutral'
                ]

                # NOTE(review): the model and checkpoint are reloaded on
                # every keypress -- hoisting this out of the loop would
                # avoid repeated disk I/O.
                net = VGG16.Net()
                checkpoint = torch.load(cfg.ckpt_path)
                net.load_state_dict(checkpoint['net'])

                if use_cuda:
                    net.to(device)

                net.eval()
                # transform_test appears to produce several crops; average
                # the predictions over all of them.
                ncrops, c, h, w = np.shape(inputs)
                inputs = inputs.view(-1, c, h, w)
                # NOTE(review): Variable(..., volatile=True) is the legacy
                # pre-0.4 PyTorch inference API -- confirm the torch version.
                inputs = Variable(inputs, volatile=True)
                if use_cuda:
                    inputs = inputs.to(device)
                outputs = net(inputs)
                outputs_avg = outputs.view(ncrops, -1).mean(0)
                # NOTE(review): F.softmax without an explicit dim warns on
                # modern torch; here it is only printed, not used further.
                score = F.softmax(outputs_avg)
                print(score)
                _, predicted = torch.max(outputs_avg.data, 0)
                predicted_class = class_names[int(predicted.cpu().numpy())]
                print(predicted_class)

            if predicted_class is not None:
                cv2.putText(frame, predicted_class, (30, 60),
                            cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 1)
                cv2.imwrite(os.path.join(cfg.output_folder, 'predict.jpg'),
                            frame)

        # Keep showing the last prediction on every frame.
        if predicted_class is not None:
            cv2.putText(frame, predicted_class, (30, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 1)
        cv2.imshow('camera', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
Ejemplo n.º 13
0
from model import VGG16
from utils import Data
import h5py
import pickle
import numpy as np

# Dataset root and the two weight files (pretrained VGG16 base + our head).
root_path = '/home/jhy/Desktop/ece6254/ece6254_data'
vgg16_weights_path = '/home/jhy/Desktop/ece6254/ece6254_data/vgg16_weights.h5'
our_weights_path = '/home/jhy/Desktop/ece6254/ece6254_data/weights.h5'

# load_size=1: stream one sample at a time.
data = Data(root_path, load_size=1)
model = VGG16(our_weights_path, vgg16_weights_path)

# Evaluate an SVM on top of the VGG features.
model.test('SVM')
# NOTE(review): dead code below (Python 2 print syntax inside the string)
# kept verbatim for reference.
'''X,y = model.load_features(root_path)
X,y = np.array(X), np.array(y)
print X.shape, y.shape'''

#Y,X =  model.predict(data, batch_size=32, nb_epoch=1)
#output1 = open('feature.pkl','wb')
#output2 = open('label.pkl','wb')
#pickle.dump(X,output1,-1)
#pickle.dump(Y,output2,-1)
#output1.close()
#output2.close()
Ejemplo n.º 14
0
def main():
    """Train VGG16 on MNIST (upscaled to 3x224x224) and keep the best model.

    Saves 'best_mnist_vgg.pt' whenever the test accuracy improves on the
    module-level best_accuracy.
    """
    global best_accuracy
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Adapt MNIST to VGG input: upscale 28x28 -> 224x224 with bicubic
    # resampling, replicate the single grayscale channel to three channels,
    # then normalize each channel to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((224, 224), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.cat([x, x, x], 0)),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True, transform=transform),
        batch_size=args.batch_size,
        shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transform),
        batch_size=args.test_batch_size, shuffle=True, **loader_kwargs)

    model = VGG16().to(device)
    optimizer = optim.Adam(model.parameters())
    print(model)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        accuracy = test(args, model, device, test_loader)

        # Checkpoint only when the test accuracy improves.
        if args.save_model and accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), "best_mnist_vgg.pt")
            print("best accuracy model is updated")
Ejemplo n.º 15
0
def main():
    """Train the flux-prediction VGG16 model.

    Initializes the backbone from 'vgg16_pretrain.pth', trains with
    per-parameter-group learning rates (biases at 2x, newly added layers at
    10x/20x the base rate), and periodically logs losses, visualizes
    predictions, and snapshots weights. Returns after 4e5 global steps.
    """

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    if not os.path.exists(args.train_debug_vis_dir + args.dataset):
        os.makedirs(args.train_debug_vis_dir + args.dataset)

    model = VGG16()

    # Copy the first 26 tensors of the pretrained checkpoint into the model
    # by POSITION (not by name) -- assumes both state dicts enumerate the
    # backbone layers in the same order. TODO confirm against VGG16's layout.
    saved_dict = torch.load('vgg16_pretrain.pth')
    model_dict = model.state_dict()
    saved_key = list(saved_dict.keys())
    model_key = list(model_dict.keys())

    for i in range(26):
        model_dict[model_key[i]] = saved_dict[saved_key[i]]

    model.load_state_dict(model_dict)

    model.train()
    model.cuda()

    # Four parameter groups: backbone weights/biases at 1x/2x the base LR,
    # newly added layers at 10x/20x.
    optimizer = torch.optim.Adam(params=[
        {
            "params": get_params(model, key="backbone", bias=False),
            "lr": INI_LEARNING_RATE
        },
        {
            "params": get_params(model, key="backbone", bias=True),
            "lr": 2 * INI_LEARNING_RATE
        },
        {
            "params": get_params(model, key="added", bias=False),
            "lr": 10 * INI_LEARNING_RATE
        },
        {
            "params": get_params(model, key="added", bias=True),
            "lr": 20 * INI_LEARNING_RATE
        },
    ],
                                 weight_decay=WEIGHT_DECAY)

    dataloader = DataLoader(FluxSegmentationDataset(dataset=args.dataset,
                                                    mode='train'),
                            batch_size=1,
                            shuffle=True,
                            num_workers=4)

    global_step = 0

    for epoch in range(1, EPOCHES):

        for i_iter, batch_data in enumerate(dataloader):

            global_step += 1

            Input_image, vis_image, gt_mask, gt_flux, weight_matrix, dataset_lendth, image_name = batch_data

            optimizer.zero_grad()

            pred_flux = model(Input_image.cuda())

            # Loss has a magnitude (norm) term and a direction (angle) term.
            norm_loss, angle_loss = loss_calc(pred_flux, gt_flux,
                                              weight_matrix)

            total_loss = norm_loss + angle_loss

            total_loss.backward()

            optimizer.step()

            # Log every 100 steps.
            if global_step % 100 == 0:
                print('epoche {} i_iter/total {}/{} norm_loss {:.2f} angle_loss {:.2f}'.format(\
                       epoch, i_iter, int(dataset_lendth.data), norm_loss, angle_loss))

            # Visualize predictions every 500 steps.
            if global_step % 500 == 0:
                vis_flux(vis_image, pred_flux, gt_flux, gt_mask, image_name,
                         args.train_debug_vis_dir + args.dataset + '/')

            # Snapshot every 10k steps (float modulo on an int is exact here).
            if global_step % 1e4 == 0:
                torch.save(
                    model.state_dict(), args.snapshot_dir + args.dataset +
                    '_' + str(global_step) + '.pth')

            # Hard stop after 400k steps.
            if global_step % 4e5 == 0:
                return
Ejemplo n.º 16
0
def train(FLAG):
    """Train VGG16 with channel-sparsity regularization on CIFAR-10/100.

    Builds a VGG16 from the ``FLAG.init_from`` weight file, attaches the
    sparsity training objective, then runs an optimization loop with
    imgaug background augmentation, early stopping on validation accuracy,
    gamma-based sparsification of the best weights, and finally appends
    the run's hyper-parameters/results to a fixed CSV file.

    Args:
        FLAG: argparse-style namespace providing at least ``dataset``,
            ``init_from``, ``prof_type``, ``lambda_s``, ``lambda_m``,
            ``decay``, ``keep_prob`` and ``save_dir``.  Several summary
            attributes (``optimizer``, ``lr``, ``batch_size``,
            ``epoch_end``, ``val_accu``) are written back onto it.

    Raises:
        ValueError: if ``FLAG.dataset`` is not 'CIFAR-10' or 'CIFAR-100'.
    """
    print("Reading dataset...")
    if FLAG.dataset == 'CIFAR-10':
        train_data = CIFAR10(train=True)
        test_data = CIFAR10(train=False)
        vgg16 = VGG16(classes=10)
    elif FLAG.dataset == 'CIFAR-100':
        train_data = CIFAR100(train=True)
        test_data = CIFAR100(train=False)
        vgg16 = VGG16(classes=100)
    else:
        raise ValueError("dataset should be either CIFAR-10 or CIFAR-100.")
    print("Build VGG16 models for %s..." % FLAG.dataset)

    Xtrain, Ytrain = train_data.train_data, train_data.train_labels
    Xtest, Ytest = test_data.test_data, test_data.test_labels

    # Conv layers start from the pre-trained npy weights; FC layers are
    # trained from scratch.
    vgg16.build(vgg16_npy_path=FLAG.init_from,
                prof_type=FLAG.prof_type,
                conv_pre_training=True,
                fc_pre_training=False)
    vgg16.sparsity_train(l1_gamma=FLAG.lambda_s,
                         l1_gamma_diff=FLAG.lambda_m,
                         decay=FLAG.decay,
                         keep_prob=FLAG.keep_prob)

    # define tasks (only one here; kept as a list for extensibility)
    tasks = ['var_dp']
    print(tasks)

    # initial task
    cur_task = tasks[0]
    obj = vgg16.loss_dict[tasks[0]]

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=len(tasks))

    checkpoint_path = os.path.join(FLAG.save_dir, 'model.ckpt')
    tvars_trainable = tf.trainable_variables()

    #for rm in vgg16.gamma_var:
    #    tvars_trainable.remove(rm)
    #    print('%s is not trainable.'% rm)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # hyper parameters
        batch_size = 64
        epoch = 500
        early_stop_patience = 50
        min_delta = 0.0001
        opt_type = 'adam'

        # recorder
        epoch_counter = 0

        # optimizer
        global_step = tf.Variable(0, trainable=False)

        # Passing global_step to minimize() will increment it at each step.
        # BUGFIX: compare strings with ==, not `is` -- identity of string
        # literals is an interning detail (and raises SyntaxWarning).
        if opt_type == 'sgd':
            start_learning_rate = 1e-4  # adam # 4e-3 #sgd
            half_cycle = 20000
            learning_rate = tf.train.exponential_decay(start_learning_rate,
                                                       global_step,
                                                       half_cycle,
                                                       0.5,
                                                       staircase=True)
            opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                             momentum=0.9,
                                             use_nesterov=True)
        else:
            start_learning_rate = 1e-4  # adam # 4e-3 #sgd
            half_cycle = 10000
            learning_rate = tf.train.exponential_decay(start_learning_rate,
                                                       global_step,
                                                       half_cycle,
                                                       0.5,
                                                       staircase=True)
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)

        train_op = opt.minimize(obj,
                                global_step=global_step,
                                var_list=tvars_trainable)

        # progress bar (Jupyter widgets)
        ptrain = IntProgress()
        pval = IntProgress()
        display(ptrain)
        display(pval)
        ptrain.max = int(Xtrain.shape[0] / batch_size)
        pval.max = int(Xtest.shape[0] / batch_size)

        spareness = vgg16.spareness(thresh=0.05)
        print("initial spareness: %s" % sess.run(spareness))

        # re-initialize variables created after global init (optimizer slots)
        initialize_uninitialized(sess)

        # reset due to adding a new task
        patience_counter = 0
        current_best_val_accu = 0

        # optimize when the aggregated obj
        while (patience_counter < early_stop_patience
               and epoch_counter < epoch):

            def load_batches():
                # Yield fixed-size mini-batches; the tail remainder that
                # doesn't fill a batch is dropped.
                for i in range(int(Xtrain.shape[0] / batch_size)):
                    st = i * batch_size
                    ed = (i + 1) * batch_size
                    batch = ia.Batch(images=Xtrain[st:ed, :, :, :],
                                     data=Ytrain[st:ed, :])
                    yield batch

            # Augment batches in background worker processes.
            batch_loader = ia.BatchLoader(load_batches)
            bg_augmenter = ia.BackgroundAugmenter(batch_loader=batch_loader,
                                                  augseq=transform,
                                                  nb_workers=4)

            # start training
            stime = time.time()
            bar_train = Bar(
                'Training',
                max=int(Xtrain.shape[0] / batch_size),
                suffix='%(index)d/%(max)d - %(percent).1f%% - %(eta)ds')
            bar_val = Bar(
                'Validation',
                max=int(Xtest.shape[0] / batch_size),
                suffix='%(index)d/%(max)d - %(percent).1f%% - %(eta)ds')
            train_loss, train_accu = 0.0, 0.0
            while True:
                batch = bg_augmenter.get_batch()
                if batch is None:
                    print("Finished epoch.")
                    break
                x_images_aug = batch.images_aug
                y_images = batch.data
                loss, accu, _ = sess.run(
                    [obj, vgg16.accu_dict[cur_task], train_op],
                    feed_dict={
                        vgg16.x: x_images_aug,
                        vgg16.y: y_images,
                        vgg16.is_train: True
                    })
                bar_train.next()
                train_loss += loss
                train_accu += accu
                ptrain.value += 1
                ptrain.description = "Training %s/%s" % (ptrain.value,
                                                         ptrain.max)
            train_loss = train_loss / ptrain.value
            train_accu = train_accu / ptrain.value
            batch_loader.terminate()
            bg_augmenter.terminate()

            # validation (fixed chunks of 200, no augmentation)
            val_loss = 0
            val_accu = 0
            for i in range(int(Xtest.shape[0] / 200)):
                st = i * 200
                ed = (i + 1) * 200
                loss, accu = sess.run(
                    [obj, vgg16.accu_dict[cur_task]],
                    feed_dict={
                        vgg16.x: Xtest[st:ed, :],
                        vgg16.y: Ytest[st:ed, :],
                        vgg16.is_train: False
                    })
                val_loss += loss
                val_accu += accu
                pval.value += 1
                # BUGFIX: the denominator was pval.value twice.
                pval.description = "Testing %s/%s" % (pval.value, pval.max)
            val_loss = val_loss / pval.value
            val_accu = val_accu / pval.value

            print("\nspareness: %s" % sess.run(spareness))
            # early stopping check: keep the best weights snapshot
            if (val_accu - current_best_val_accu) > min_delta:
                current_best_val_accu = val_accu
                patience_counter = 0

                para_dict = sess.run(vgg16.para_dict)
                np.save(os.path.join(FLAG.save_dir, "para_dict.npy"),
                        para_dict)
                print("save in %s" %
                      os.path.join(FLAG.save_dir, "para_dict.npy"))
            else:
                patience_counter += 1

            # shuffle Xtrain and Ytrain in the next epoch
            idx = np.random.permutation(Xtrain.shape[0])
            Xtrain, Ytrain = Xtrain[idx, :, :, :], Ytrain[idx, :]

            # epoch end
            # writer.add_summary(epoch_summary, epoch_counter)
            epoch_counter += 1

            ptrain.value = 0
            pval.value = 0
            bar_train.finish()
            bar_val.finish()

            print(
                "Epoch %s (%s), %s sec >> train loss: %.4f, train accu: %.4f, val loss: %.4f, val accu at %s: %.4f"
                % (epoch_counter, patience_counter,
                   round(time.time() - stime, 2), train_loss, train_accu,
                   val_loss, cur_task, val_accu))
        saver.save(sess, checkpoint_path, global_step=epoch_counter)

        # NOTE(review): para_dict is only bound if at least one epoch
        # improved validation accuracy; with min_delta > 0 the very first
        # epoch qualifies unless accuracy is exactly 0 -- confirm.
        sp, rcut = gammaSparsifyVGG16(para_dict, thresh=0.02)
        np.save(os.path.join(FLAG.save_dir, "sparse_dict.npy"), sp)
        print("sparsify %s in %s" % (np.round(
            1 - rcut, 3), os.path.join(FLAG.save_dir, "sparse_dict.npy")))

        #writer.close()
        arr_spareness.append(1 - rcut)
        np.save(os.path.join(FLAG.save_dir, "sprocess.npy"), arr_spareness)
    # Record run configuration/results back onto FLAG for the CSV row.
    FLAG.optimizer = opt_type
    FLAG.lr = start_learning_rate
    FLAG.batch_size = batch_size
    FLAG.epoch_end = epoch_counter
    FLAG.val_accu = current_best_val_accu

    # Append one CSV row per run (header only written on first creation).
    header = ''
    row = ''
    for key in sorted(vars(FLAG)):
        # BUGFIX: string comparison with ==, not `is`.
        if header == '':
            header = key
            row = str(getattr(FLAG, key))
        else:
            header += "," + key
            row += "," + str(getattr(FLAG, key))
    row += "\n"
    header += "\n"
    if os.path.exists("/home/cmchang/new_CP_CNN/model.csv"):
        with open("/home/cmchang/new_CP_CNN/model.csv", "a") as myfile:
            myfile.write(row)
    else:
        with open("/home/cmchang/new_CP_CNN/model.csv", "w") as myfile:
            myfile.write(header)
            myfile.write(row)
Ejemplo n.º 17
0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(save_dir)
        pred = []
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            sess.run(tf.global_variables())
            for i in range(test.images.shape[0]):
                output = sess.run(model.output,
                                   feed_dict={model.x: test.images[i:(i+1),:],
                                              model.y: test.labels[i:(i+1),:],
                                              model.w: [[1.0]]
                                             })

                pred.append(output)
            pred = np.reshape(pred,newshape=(-1,1))
            plot_making(save_dir, test.labels, pred, types="test")
            
            print("mean square error: %.4f" % np.mean(np.square(test.labels-pred)))

# with tf.variable_scope("test") as scope:
# Build the VGG16 network from the pre-trained npy weight file.
vgg16 = VGG16(vgg16_npy_path=init_from)
vgg16.build()

# Train with early stopping, then plot predictions on the test split.
# (train/test/save_dir/init_from are defined earlier in this script.)
vgg16_train(vgg16, train, test, save_dir=save_dir, init_from=init_from, batch_size=64, epoch=300, early_stop_patience=25)
# Persist the result plots for the trained model.
save_plot(vgg16, save_dir, test)

Ejemplo n.º 18
0
def train():
    """Train a plant-seedling classifier (VGG16 / VGG19 / GoogLeNet).

    Reads the model name from stdin, splits the training folder 75/25
    into train/validation subsets, trains with SGD and early stopping on
    validation loss, logs per-epoch loss/accuracy to a CSV, and saves the
    early-stopping checkpoint as the final model.

    Raises:
        ValueError: if the entered model name is not 'vgg16', 'vgg19'
            or 'googlenet'.
    """
    data_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_set = PlantSeedlingData(Path(ROOTDIR).joinpath('train'), data_transform)

    # Random 75/25 train/validation split over the same dataset object;
    # the samplers keep the two subsets disjoint.
    valid_size = 0.25
    num_train = len(train_set)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    split = int(np.floor(valid_size * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = DataLoader(train_set, batch_size=32, sampler=train_sampler, num_workers=1)
    valid_loader = DataLoader(train_set, batch_size=32, sampler=valid_sampler, num_workers=1)

    # device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Input model: ")
    which_model = input()
    if which_model == "vgg16":
        model = VGG16(num_classes=train_set.num_classes)
        num_epochs = 50
    elif which_model == "vgg19":
        model = VGG19(num_classes=train_set.num_classes)
        num_epochs = 100
    elif which_model == "googlenet":
        model = GOOGLENET(num_classes=train_set.num_classes)
        num_epochs = 100
    else:
        # BUGFIX: previously fell through with `model` unbound and crashed
        # later with a NameError; fail fast with a clear message instead.
        raise ValueError("unknown model: %s" % which_model)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda("cuda:0")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
    # Rows of [train_loss, train_acc, valid_loss, valid_acc], one per epoch.
    loss_acc = np.empty((0, 4), dtype=float)

    early_stopping = EarlyStopping(patience=15, verbose=True)

    for epoch in range(num_epochs):
        print(f'Epoch: {epoch + 1}/{num_epochs}')
        print('-' * len(f'Epoch: {epoch + 1}/{num_epochs}'))

        # BUGFIX: re-enable training mode every epoch; model.eval() below
        # used to stick for every epoch after the first, so dropout and
        # batch-norm ran in inference mode during training.
        model.train()

        training_loss = 0.0
        training_corrects = 0
        valid_loss = 0.0
        valid_corrects = 0

        for inputs, labels in train_loader:
            if use_cuda:
                inputs = Variable(inputs.cuda("cuda:0"))
                labels = Variable(labels.cuda("cuda:0"))
            else:
                inputs = Variable(inputs)
                labels = Variable(labels)

            optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Weight the running loss by batch size; the last batch may be
            # smaller than 32.
            training_loss += loss.item() * inputs.size(0)
            training_corrects += torch.sum(preds == labels.data)

        training_loss = training_loss / (len(train_set) - split)
        training_acc = float(training_corrects) / (len(train_set) - split)

        # Validation: no gradients, and batches moved to the same device
        # as the model (BUGFIX: the old code fed CPU tensors to a CUDA
        # model, which fails at runtime when a GPU is available).
        model.eval()
        with torch.no_grad():
            for _data, target in valid_loader:
                if use_cuda:
                    _data = _data.cuda("cuda:0")
                    target = target.cuda("cuda:0")
                outputs = model(_data)
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, target)

                valid_loss += loss.item() * _data.size(0)
                valid_corrects += torch.sum(preds == target.data)

        valid_loss = valid_loss / split
        valid_acc = float(valid_corrects) / split

        loss_acc = np.append(loss_acc, np.array([[training_loss, training_acc, valid_loss, valid_acc]]), axis=0)

        print_msg = (f'train_loss: {training_loss:.4f} valid_loss: {valid_loss:.4f}\t' +
                     f'train_acc: {training_acc:.4f} valid_acc: {valid_acc:.4f}')
        print(print_msg)

        # EarlyStopping saves the best weights to 'checkpoint4.pt' and
        # flips early_stop once patience is exhausted.
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early Stopping")
            break

    loss_acc = np.round(loss_acc, 4)
    np.savetxt('googlenet4-train_loss_acc.csv', loss_acc, delimiter=',')

    # Reload the best (early-stopping) weights and save the full model.
    model.load_state_dict(torch.load('checkpoint4.pt'))
    torch.save(model, 'googlenet4-best-train-acc.pth')
Ejemplo n.º 19
0
            cv2.imread("../data/test/%d.png" %
                       (i), cv2.IMREAD_GRAYSCALE) - 5) / 50
    ] for i in batch])
    #print(x.shape)
    y = np.asarray([Y[i] for i in batch])
    #print(y.shape)
    x = torch.tensor(x)
    x = x.float()
    y = np.asarray(y)
    y = torch.tensor(y)
    y = y.float()
    return x, y


# Build the model and try to warm-start it from a previous checkpoint;
# fall back to fresh weights if loading fails for any reason.
model = VGG16(kernel_size=kernel_size,
              padding_size=padding_size,
              dropout_prob=dropout_prob)
try:
    model.load_state_dict(torch.load('../data/mytraining.pt'))
    write('Model succesfully loaded.\n')
except Exception:
    # BUGFIX: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; best-effort loading should only ignore real
    # errors (missing file, incompatible state dict, ...).
    write(
        "Training file not found/ incompatible with model/ some other error.\n"
    )
model.to(device)

# Inference only from here on.
model = model.eval()

# Accumulators for predictions and ground truth across batches.
global_y = np.asarray([])
global_y_test = np.asarray([])
Ejemplo n.º 20
0
    def sampling(self, split_dir):
        """Run the trained generator over the data loader and save outputs.

        For five caption indices, iterates ``self.data_loader``, encodes
        the ``wrong_caps`` captions with the text encoder, feeds the real
        images plus text embeddings and noise through ``netG`` (with VGG
        features), and writes a per-batch grid image ("super") as well as
        one file per generated sample ("single") under
        ``<DATA_DIR>/output/<netG>/valid/valid_<i>/``.

        Args:
            split_dir: split name; 'test' is remapped to 'valid'.
        """
        if cfg.TRAIN.NET_G == '':
            # No generator checkpoint configured -- nothing to sample.
            print('Error: the path for main module is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'

            # Build the generator architecture selected by the config.
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = EncDecNet()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            # The text encoder (weights from cfg.TRAIN.NET_E, frozen).
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()
            # The image encoder (disabled; kept for reference)
            """
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()
            """

            # The VGG network (feature extractor passed into netG).
            VGG = VGG16()
            print("Load the VGG model")
            VGG.cuda()
            VGG.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            # NOTE(review): Variable(..., volatile=True) is the legacy
            # pre-0.4 way to disable autograd; torch.no_grad() is the
            # modern equivalent -- confirm the pinned torch version.
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
            noise = noise.cuda()

            # Load the fixed epoch-600 generator checkpoint.
            model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                     'Model/netG_epoch_600.pth')
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save modified images
            save_dir_valid = os.path.join(cfg.DATA_DIR, 'output',
                                          self.args.netG, 'valid')
            #mkdir_p(save_dir)

            cnt = 0  # number of samples processed (diagnostic only)
            idx = 0  # global counter used to name the single-image files
            for i in range(5):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                # the path to save modified images
                save_dir = os.path.join(save_dir_valid, 'valid_%d' % i)
                save_dir_super = os.path.join(save_dir, 'super')
                save_dir_single = os.path.join(save_dir, 'single')
                mkdir_p(save_dir_super)
                mkdir_p(save_dir_single)
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data)

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################

                    hidden = text_encoder.init_hidden(batch_size)

                    # Embeddings come from wrong_caps (presumably the
                    # mismatched captions -- verify against prepare_data).
                    words_embs, sent_emb = text_encoder(
                        wrong_caps, wrong_caps_len, hidden)
                    words_embs, sent_emb = words_embs.detach(
                    ), sent_emb.detach()

                    # Mask padding positions (token id 0) in the captions,
                    # truncated to the word-embedding length.
                    mask = (wrong_caps == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    #######################################################
                    # (2) Modify real images
                    ######################################################

                    # Fresh Gaussian noise for every batch (in-place).
                    noise.data.normal_(0, 1)

                    fake_img, mu, logvar = netG(imgs[-1], sent_emb, words_embs,
                                                noise, mask, VGG)

                    # Per-batch grid of real/fake images with captions.
                    img_set = build_images(imgs[-1], fake_img, captions,
                                           wrong_caps, self.ixtoword)
                    img = Image.fromarray(img_set)
                    full_path = '%s/super_step%d.png' % (save_dir_super, step)
                    img.save(full_path)

                    # Save each generated image of the batch individually.
                    for j in range(batch_size):
                        s_tmp = '%s/single' % (save_dir_single)
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1
                        im = fake_img[j].data.cpu().numpy()
                        # NOTE(review): the usual [-1, 1] -> [0, 255]
                        # rescale is commented out below, so the uint8
                        # cast may clip/wrap -- confirm the generator's
                        # output range before relying on these files.
                        #im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, idx)
                        idx = idx + 1
                        im.save(fullpath)
Ejemplo n.º 21
0
    def gen_example(self, data_dic):
        if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
            print('Error: the path for main module or DCM is not found!')
        else:
            # The text encoder
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            # The image encoder
            """
            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load image encoder from:', img_encoder_path)
            image_encoder = image_encoder.cuda()
            image_encoder.eval()
            """
            """
            image_encoder = CNN_dummy()
            image_encoder = image_encoder.cuda()
            image_encoder.eval()
            """

            # The VGG network
            VGG = VGG16()
            print("Load the VGG model")
            VGG.cuda()
            VGG.eval()

            # The main module
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = EncDecNet()
            s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
            s_tmp = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                 'valid/gen_example')

            model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                     'Model/netG_epoch_8.pth')
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)
            #netG = nn.DataParallel(netG, device_ids= self.gpus)
            netG.cuda()
            netG.eval()

            for key in data_dic:
                save_dir = '%s/%s' % (s_tmp, key)
                mkdir_p(save_dir)
                captions, cap_lens, sorted_indices, imgs = data_dic[key]

                batch_size = captions.shape[0]
                nz = cfg.GAN.Z_DIM
                captions = Variable(torch.from_numpy(captions), volatile=True)
                cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)

                captions = captions.cuda()
                cap_lens = cap_lens.cuda()
                for i in range(1):
                    noise = Variable(torch.FloatTensor(batch_size, nz),
                                     volatile=True)
                    noise = noise.cuda()

                    #######################################################
                    # (1) Extract text and image embeddings
                    ######################################################
                    hidden = text_encoder.init_hidden(batch_size)

                    # The text embeddings
                    words_embs, sent_emb = text_encoder(
                        captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(
                    ), sent_emb.detach()

                    # The image embeddings
                    mask = (captions == 0)
                    #######################################################
                    # (2) Modify real images
                    ######################################################
                    noise.data.normal_(0, 1)

                    imgs_256 = imgs[-1].unsqueeze(0).repeat(
                        batch_size, 1, 1, 1)
                    enc_features = VGG(imgs_256)
                    fake_img, mu, logvar = nn.parallel.data_parallel(
                        netG, (imgs[-1], sent_emb, words_embs, noise, mask,
                               enc_features), self.gpus)

                    cap_lens_np = cap_lens.cpu().data.numpy()

                    one_imgs = []
                    for j in range(captions.shape[0]):
                        font = ImageFont.truetype('./FreeMono.ttf', 20)
                        canv = Image.new('RGB', (256, 256), (255, 255, 255))
                        draw = ImageDraw.Draw(canv)
                        sent = []
                        for k in range(len(captions[j])):
                            if (captions[j][k] == 0):
                                break
                            word = self.ixtoword[captions[j][k].item()].encode(
                                'ascii', 'ignore').decode('ascii')
                            if (k % 2 == 1):
                                word = word + '\n'
                            sent.append(word)
                        fake_sent = ' '.join(sent)
                        draw.text((0, 0), fake_sent, font=font, fill=(0, 0, 0))
                        canv_np = np.asarray(canv)

                        real_im = imgs[-1]
                        real_im = (real_im + 1) * 127.5
                        real_im = real_im.cpu().numpy().astype(np.uint8)
                        real_im = np.transpose(real_im, (1, 2, 0))

                        fake_im = fake_img[j]
                        fake_im = (fake_im + 1.0) * 127.5
                        fake_im = fake_im.detach().cpu().numpy().astype(
                            np.uint8)
                        fake_im = np.transpose(fake_im, (1, 2, 0))

                        one_img = np.concatenate([real_im, canv_np, fake_im],
                                                 axis=1)
                        one_imgs.append(one_img)

                    img_set = np.concatenate(one_imgs, axis=0)
                    super_img = Image.fromarray(img_set)
                    full_path = os.path.join(save_dir, 'super.png')
                    super_img.save(full_path)
                    """
                    for j in range(5): ## batch_size
                        save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
                        for k in range(len(fake_imgs)):
                            im = fake_imgs[k][j].data.cpu().numpy()
                            im = (im + 1.0) * 127.5
                            im = im.astype(np.uint8)
                            im = np.transpose(im, (1, 2, 0))
                            im = Image.fromarray(im)
                            fullpath = '%s_g%d.png' % (save_name, k)
                            im.save(fullpath)

                        for k in range(len(attention_maps)):
                            if len(fake_imgs) > 1:
                                im = fake_imgs[k + 1].detach().cpu()
                            else:
                                im = fake_imgs[0].detach().cpu()
                            attn_maps = attention_maps[k]
                            att_sze = attn_maps.size(2)
                    """
                    """
                            img_set, sentences = \
                                build_super_images2(im[j].unsqueeze(0),
                                                    captions[j].unsqueeze(0),
                                                    [cap_lens_np[j]], self.ixtoword,
                                                    [attn_maps[j]], att_sze)
                            if img_set is not None:
                                im = Image.fromarray(img_set)
                                fullpath = '%s_a%d.png' % (save_name, k)
                                im.save(fullpath)
                    """
                    """
Ejemplo n.º 22
0
from model import LanguageModel, VQA_FeatureModel, VGG16
from data_loader import ImageFeatureDataset
from data_utils import change, preprocess_text

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import torchvision.models as models

# Load the cached validation features; each entry's second field is an
# image path (feature layout otherwise opaque here -- see the dump writer).
with open('dumps/val_features_vgg16.pkl', 'rb') as f:
    v = pickle.load(f)
paths = [x[1] for x in v]

# Feature extractor in inference mode.
vgg16 = VGG16()
print('Loaded VGG16 Model')
vgg16 = vgg16.eval()

# Standard ImageNet preprocessing: resize to 224x224 and normalize.
img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_feature(path):

    img = skimage.io.imread(path)

    if (len(img.shape) == 2):
Ejemplo n.º 23
0
import os

import tensorflow as tf
from model import VGG16, VGG16Keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import cifar10

# Pin the run to the first visible GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Load CIFAR-10 and rescale pixel intensities from [0, 255] to [0, 1].
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Build the VGG16 classifier for 32x32 RGB inputs and show its layout.
model = VGG16(input_shape=(32, 32, 3))
print(model.summary())

# Integer labels + probability outputs -> sparse categorical cross-entropy.
adam = Adam(learning_rate=0.0001)
model.compile(
    optimizer=adam,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

# Train for five epochs, validating on the held-out test split each epoch.
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=5,
    batch_size=32,
)
print("Done")
Ejemplo n.º 24
0
def main():
    """Run PFA saliency inference over every .jpg in ``rgb_folder``.

    For each image, writes to ``output_folder``: the raw saliency mask
    (_mask1), a cleaned contour mask (_mask2), and optionally bounding boxes
    (_bbox), contour-point markers (_poly), and a transparent cutout (_trans).
    Relies on module-level config: target_size, dropout, with_CPFE, with_CA,
    with_SA, with_crf, draw_bound, draw_poly, draw_cutout.
    """
    # Build the saliency model and load pretrained weights by layer name.
    model_name = 'model/PFA_00050.h5'
    model_input = Input(shape=(target_size[0], target_size[1], 3))
    model = VGG16(model_input,
                  dropout=dropout,
                  with_CPFE=with_CPFE,
                  with_CA=with_CA,
                  with_SA=with_SA)
    model.load_weights(model_name, by_name=True)

    # Inference only: freeze every layer.
    for layer in model.layers:
        layer.trainable = False
    '''
    image_path = 'image/2.jpg'
    img, shape = load_image(image_path)
    img = np.array(img, dtype=np.float32)
    sa = model.predict(img)
    sa = getres(sa, shape)
    plt.title('saliency')
    plt.subplot(131)
    plt.imshow(cv2.imread(image_path))
    plt.subplot(132)
    plt.imshow(sa,cmap='gray')
    plt.subplot(133)
    edge = laplace_edge(sa)
    plt.imshow(edge,cmap='gray')
    plt.savefig(os.path.join('./train_1000_output','alpha.png'))
    #misc.imsave(os.path.join('./train_1000_output','alpha.png'), sa)
    '''
    #HOME = os.path.expanduser('~')
    #rgb_folder = os.path.join(HOME, 'data/sku_wdis_imgs/sku_wdis_imgs_12')
    rgb_folder = './tmp'
    output_folder = './train_1000_output'
    rgb_names = os.listdir(rgb_folder)
    print(rgb_folder, "\nhas {0} pics.".format(len(rgb_names)))
    start = time.time()
    for rgb_name in rgb_names:
        if rgb_name[-4:] == '.jpg':
            # NOTE(review): scipy.misc.imread/imsave were removed in
            # SciPy 1.2 -- this assumes an old pinned SciPy; confirm.
            img_org = misc.imread(os.path.join(rgb_folder, rgb_name))
            img, shape = load_image(os.path.join(rgb_folder, rgb_name))
            img = np.array(img, dtype=np.float32)
            # Raw saliency prediction, mapped back to the source shape.
            sa = model.predict(img)
            sa = getres(sa, shape)
            misc.imsave(
                os.path.join(output_folder, rgb_name[:-4] + '_mask1.png'), sa)
            #1. densecrf
            # Optional DenseCRF refinement of the saliency map.
            if with_crf:
                sa = dense_crf(np.expand_dims(np.expand_dims(sa, 0), 3),
                               np.expand_dims(img_org, 0))
                sa = sa[0, :, :, 0]
            #2. reduce contain relationship
            # Threshold the map, then zero out connected components whose
            # area is smaller than 1/threshold_area of the image.
            threshold_gray = 2
            threshold_area = 100
            connectivity = 8
            sa = sa.astype(np.uint8)
            _, sa = cv2.threshold(sa, threshold_gray, 255, 0)
            output = cv2.connectedComponentsWithStats(sa, connectivity,
                                                      cv2.CV_32S)
            stats = output[2]  # per-component rows: [x, y, w, h, area]
            area_img = img_org.shape[0] * img_org.shape[1]
            for rgns in range(1, stats.shape[0]):
                if area_img / stats[rgns, 4] <= threshold_area:
                    continue
                x1, y1 = stats[rgns, 0], stats[rgns, 1]
                x2, y2 = x1 + stats[rgns, 2], y1 + stats[rgns, 3]
                sa[y1:y2, x1:x2] = 0
            # Final binary mask: filled external contours of what survived.
            img_seg = np.zeros(img_org.shape[:3])
            # NOTE(review): the 3-value findContours return is the OpenCV 3.x
            # API; OpenCV 4.x returns (contours, hierarchy) only -- confirm
            # the pinned cv2 version.
            _, cnts, hierarchy = cv2.findContours(sa, cv2.RETR_EXTERNAL,
                                                  cv2.CHAIN_APPROX_TC89_KCOS)
            img_seg = cv2.drawContours(img_seg, cnts, -1, (255, 255, 255), -1)
            misc.imsave(
                os.path.join(output_folder, rgb_name[:-4] + '_mask2.png'),
                img_seg)
            if draw_bound:
                #   Changing the connected components to bounding boxes
                # [ref1](https://blog.csdn.net/qq_21997625/article/details/86558178)
                img_org_bound = img_org.copy()
                area_img = img_org_bound.shape[0] * img_org_bound.shape[1]
                for rgns in range(1, stats.shape[0]):
                    # Inverted condition vs. the filter above: draw boxes only
                    # for components large enough to have been kept.
                    if area_img / stats[rgns, 4] > threshold_area:
                        continue
                    x1, y1 = stats[rgns, 0], stats[rgns, 1]
                    x2, y2 = x1 + stats[rgns, 2], y1 + stats[rgns, 3]
                    cv2.rectangle(img_org_bound, (x1, y1), (x2, y2),
                                  (255, 0, 0),
                                  thickness=2)
                # [..., ::-1] flips RGB to the BGR order cv2.imwrite expects.
                cv2.imwrite(
                    os.path.join(output_folder, rgb_name[:-4] + '_bbox.png'),
                    img_org_bound[..., ::-1])
                #print(os.path.join(output_folder, rgb_name[:-4]+'.png'))
            if draw_poly:
                # [0](https://blog.csdn.net/sunny2038/article/details/12889059)
                # [1](https://blog.csdn.net/jjddss/article/details/73527990)
                img_org_poly = img_org.copy()
                if len(cnts) <= 0:
                    continue
                # Mark every contour point with a small square marker.
                for i in range(len(cnts)):
                    cnt = cnts[i]
                    _n = cnt.shape[0]
                    if _n <= 2:
                        continue
                    cnt = cnt.reshape((_n, 2))
                    cnt = tuple(map(tuple, cnt))
                    for j in range(_n):
                        img_org_poly = cv2.drawMarker(
                            img_org_poly,
                            cnt[j], (255, 0, 0),
                            markerType=cv2.MARKER_SQUARE,
                            markerSize=2,
                            thickness=2,
                            line_type=cv2.FILLED)
                cv2.imwrite(
                    os.path.join(output_folder, rgb_name[:-4] + '_poly.png'),
                    img_org_poly[..., ::-1])
            if draw_cutout:
                # Copy original pixels only where the mask is set; the rest
                # stays black and is later made transparent.
                mask_bbox_change_size = cv2.resize(
                    img_seg, (img_org.shape[1], img_org.shape[0]))
                object_image = np.zeros(img_org.shape[:3], np.uint8)
                object_image = np.where(mask_bbox_change_size > 0, img_org,
                                        object_image)
                #for i in range(3):
                #    object_image[:,:,i] = np.where(mask_bbox_change_size>0, img_org[:,:,i], object_image[:,:,i])
                misc.imsave(
                    os.path.join(output_folder, rgb_name[:-4] + '_tmp.png'),
                    object_image)
                background_transparent(
                    os.path.join(output_folder, rgb_name[:-4] + '_tmp.png'),
                    os.path.join(output_folder, rgb_name[:-4] + '_trans.png'))
    end = time.time()
    print("processing done in: %.4f time" % (end - start))
Ejemplo n.º 25
0
def test(FLAG):
    """Evaluate a (possibly dp-sparsified) VGG16 on CIFAR-10 or CIFAR-100.

    Restores parameters either from an npy weight dict (sparsified to
    ``FLAG.fidelity`` when that is not None) or from the latest checkpoint
    under ``FLAG.save_dir``, evaluates accuracy at each dot-product ratio in
    ``dp``, writes a CSV to ``FLAG.output``, and appends every FLAG attribute
    as one row of a global performance.csv.

    Raises:
        ValueError: if ``FLAG.dataset`` is neither 'CIFAR-10' nor 'CIFAR-100'.
    """
    print("Reading dataset...")
    if FLAG.dataset == 'CIFAR-10':
        test_data = CIFAR10(train=False)
        vgg16 = VGG16(classes=10)
    elif FLAG.dataset == 'CIFAR-100':
        test_data = CIFAR100(train=False)
        vgg16 = VGG16(classes=100)
    else:
        raise ValueError("dataset should be either CIFAR-10 or CIFAR-100.")

    Xtest, Ytest = test_data.test_data, test_data.test_labels

    # FLAG.fidelity is None when evaluating the full model.  Use 1.0 wherever
    # the ratio is needed numerically -- the original computed
    # ``dp_i * FLAG.fidelity`` in that branch and raised TypeError.
    fidelity = 1.0 if FLAG.fidelity is None else FLAG.fidelity

    if FLAG.fidelity is not None:
        data_dict = np.load(FLAG.init_from, encoding='latin1').item()
        data_dict = dpSparsifyVGG16(data_dict, FLAG.fidelity)
        vgg16.build(vgg16_npy_path=data_dict,
                    conv_pre_training=True,
                    fc_pre_training=True)
        print("Build model from %s using dp=%s" %
              (FLAG.init_from, str(FLAG.fidelity * 100)))
    else:
        vgg16.build(vgg16_npy_path=FLAG.init_from,
                    conv_pre_training=True,
                    fc_pre_training=True)
        print("Build full model from %s" % (FLAG.init_from))

    # build model using  dp
    # dp = [(i+1)*0.05 for i in range(1,20)]
    dp = [1.0]
    vgg16.set_idp_operation(dp=dp, keep_prob=FLAG.keep_prob)

    flops, params = countFlopsParas(vgg16)
    print("Flops: %3f M, Paras: %3f M" % (flops / 1e6, params / 1e6))
    FLAG.flops_M = flops / 1e6
    FLAG.params_M = params / 1e6

    with tf.Session() as sess:
        # Initialize first in every case; a checkpoint restore (below) then
        # overwrites the initialized values.  (The original also ran a
        # pointless ``sess.run(tf.global_variables())`` fetch; removed.)
        sess.run(tf.global_variables_initializer())
        if FLAG.save_dir is not None:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(FLAG.save_dir)

            if ckpt and ckpt.model_checkpoint_path:
                # FIX: the original called ``saver.restore(sess, checkpoint)``
                # with an undefined name and raised NameError here.
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Model restored %s" % ckpt.model_checkpoint_path)
            else:
                print("No model checkpoint in %s" % FLAG.save_dir)
        print("Initialized")
        # Evaluate in two 5000-sample halves (bounds peak memory); averaging
        # the two halves is valid because they are the same size.
        output = []
        for dp_i in dp:
            accu = sess.run(vgg16.accu_dict[str(int(dp_i * 100))],
                            feed_dict={
                                vgg16.x: Xtest[:5000, :],
                                vgg16.y: Ytest[:5000, :],
                                vgg16.is_train: False
                            })
            accu2 = sess.run(vgg16.accu_dict[str(int(dp_i * 100))],
                             feed_dict={
                                 vgg16.x: Xtest[5000:, :],
                                 vgg16.y: Ytest[5000:, :],
                                 vgg16.is_train: False
                             })
            output.append((accu + accu2) / 2)
            print("At DP={dp:.4f}, accu={perf:.4f}".format(
                dp=dp_i * fidelity, perf=(accu + accu2) / 2))
        res = pd.DataFrame.from_dict({
            'DP': [int(dp_i * 100) for dp_i in dp],
            'accu': output
        })
        res.to_csv(FLAG.output, index=False)
        print("Write into %s" % FLAG.output)

    FLAG.accuracy = (accu + accu2) / 2

    # Append one CSV row of every FLAG attribute (sorted for a stable column
    # order); write the header line only when the file does not exist yet.
    header = ''
    row = ''
    for key in sorted(vars(FLAG)):
        # FIX: the original used ``header is ''`` -- identity comparison with
        # a str literal is implementation-dependent; compare by value.
        if header == '':
            header = key
            row = str(getattr(FLAG, key))
        else:
            header += "," + key
            row += "," + str(getattr(FLAG, key))
    row += "\n"
    header += "\n"
    if os.path.exists("/home/cmchang/new_CP_CNN/performance.csv"):
        with open("/home/cmchang/new_CP_CNN/performance.csv", "a") as myfile:
            myfile.write(row)
    else:
        with open("/home/cmchang/new_CP_CNN/performance.csv", "w") as myfile:
            myfile.write(header)
            myfile.write(row)
Ejemplo n.º 26
0
def main():
    """Parse CLI options and train a VGG16-style classifier on MNIST."""
    # Training settings
    parser = argparse.ArgumentParser(description='Prography-6th-assignment-HyunjinKim')
    parser.add_argument('--dataroot', default="/input/", help='path to dataset')
    parser.add_argument('--evalf', default="/eval/", help='path to evaluate sample')
    parser.add_argument('--outf', default='models',
                        help='folder to output images and model checkpoints')
    parser.add_argument('--ckpf', default='',
                        help="path to model checkpoint file (to continue training)")

    #### Batch size ####
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 4)')
    #### Epochs ####
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--train', action='store_true',
                        help='training a VGG16 modified model on MNIST dataset')
    parser.add_argument('--evaluate', action='store_true',
                        help='evaluate a [pre]trained model')

    args = parser.parse_args()

    # Select the execution device; pinned memory and a worker thread only
    # help when transferring batches to a GPU.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # MNIST is grayscale: replicate the single channel three times so the
    # 3-channel VGG16 input is satisfied, then normalize each channel.
    to_rgb = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.cat([x, x, x], 0)),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    # Training and test loaders over MNIST (downloaded on first use).
    train_loader = torch.utils.data.DataLoader(
        dataset=datasets.MNIST(root='./data/',
                               train=True,
                               transform=to_rgb,
                               download=True),
        batch_size=args.batch_size,
        shuffle=True,
        **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        dataset=datasets.MNIST(root='./data/',
                               train=False,
                               transform=to_rgb),
        batch_size=args.test_batch_size,
        shuffle=True,
        **loader_kwargs)

    model = VGG16().to(device)
    print("model : ", model)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # One train + test pass per epoch, checkpointing the weights each time
    # (path is Colab/Google-Drive specific).
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader, epoch)
        torch.save(model.state_dict(), '/content/drive/My Drive/prography/model/vgg16_model_epoch_%d.pth' % (epoch))
Ejemplo n.º 27
0

def inference(model, sample_path):
    """Classify every image matching *sample_path* and display each result.

    Each matched file is preprocessed with the module-level ``rgb_transform``,
    run through *model* on ``device``, and shown via matplotlib with the
    argmax class index as the title; the prediction is also printed.
    """
    for image in glob.glob(sample_path):
        # Preprocess to [3,224,224], add a batch axis -> [1,3,224,224].
        tensor = rgb_transform(Image.open(image))
        batch = tensor.unsqueeze(0)
        #print("size of extended image : ", batch.size())
        logits = model(batch.to(device))
        prediction = logits.max(dim=1)[1].item()

        # Show the untransformed image with the predicted class on top.
        plt.imshow(mpimg.imread(image))
        plt.title("prediction : " + str(prediction))
        plt.show()

        print("Prediction result : ", prediction)


# Build the classifier and restore the epoch-2 checkpoint saved by training
# (path is Colab/Google-Drive specific).
model = VGG16().to(device)
model.load_state_dict(
    torch.load(
        "/content/drive/My Drive/prography/model/vgg16_model_epoch_2.pth"))

sample_image_path = "/content/drive/My Drive/prography/sample/*jpg"  # put sample image to sample file as jpg extension

# Classify and display every sample image matching the glob above.
inference(model, sample_image_path)
Ejemplo n.º 28
0
def image_predict(image):
    """Predict the facial expression in one image and plot/save the result.

    Args:
        image: file name inside ``arg.input_folder``.

    Side effects: prints the predicted expression and either saves the
    result figure under ``arg.output_folder`` (when ``arg.SAVE_FLAG``) or
    shows it interactively (when ``arg.show_resultimg``).
    """
    image_path = os.path.join(arg.input_folder, image)
    src_img = np.array(Image.open(image_path))
    face_img, face_coor = image_face_detect.face_d(src_img)
    # Grayscale -> 240x240 -> replicate to 3 channels so the network's
    # 3-channel input is satisfied.
    gray = cv2.cvtColor(face_img, cv2.COLOR_RGB2GRAY)
    gray = cv2.resize(gray, (240, 240))
    img = gray[:, :, np.newaxis]
    img = np.concatenate((img, img, img), axis=2)
    img = Image.fromarray(np.uint8(img))
    inputs = transform_test(img)

    class_names = [
        'Angry', 'Disgusted', 'Fearful', 'Happy', 'Sad', 'Surprised', 'Neutral'
    ]

    net = VGG16.Net()
    checkpoint = torch.load(arg.ckpt_path)
    net.load_state_dict(checkpoint['net'])
    if use_cuda:
        net.cuda()
    net.eval()
    # transform_test yields several crops; flatten them into one batch and
    # average their predictions below.
    ncrops, c, h, w = np.shape(inputs)
    inputs = inputs.view(-1, c, h, w)

    if use_cuda:
        inputs = inputs.to(device)
    # FIX: Variable(..., volatile=True) has been a no-op since PyTorch 0.4;
    # torch.no_grad() is what actually disables autograd during inference.
    with torch.no_grad():
        outputs = net(inputs)
    outputs_avg = outputs.view(ncrops, -1).mean(0)
    # Explicit dim avoids the deprecated implicit-dim softmax behavior;
    # outputs_avg is 1-D, so dim=0 matches the original result.
    score = F.softmax(outputs_avg, dim=0)
    _, predicted = torch.max(outputs_avg.data, 0)
    expression = class_names[int(predicted.cpu().numpy())]
    if face_coor is not None:
        [x, y, w, h] = face_coor
        cv2.rectangle(src_img, (x, y), (x + w, y + h), (255, 0, 0), 2)

    # Left panel: input image with the detected face box.
    plt.rcParams['figure.figsize'] = (11, 6)
    axes = plt.subplot(1, 2, 1)
    plt.imshow(src_img)
    plt.title('Input Image', fontsize=20)
    axes.set_xticks([])
    axes.set_yticks([])
    plt.tight_layout()
    plt.subplots_adjust(left=0.05,
                        bottom=0.2,
                        right=0.95,
                        top=0.9,
                        hspace=0.02,
                        wspace=0.3)
    # Right panel: per-class probability bars.
    plt.subplot(1, 2, 2)
    ind = 0.1 + 0.6 * np.arange(len(class_names))
    width = 0.4
    for i in range(len(class_names)):
        plt.bar(ind[i], score.data.cpu().numpy()[i], width, color='orangered')

    plt.title("Result Analysis", fontsize=20)
    plt.xticks(ind, class_names, rotation=30, fontsize=12)
    if arg.SAVE_FLAG:
        # FIX: the original created the hard-coded './image_output' here,
        # so saving failed for any other --output_folder value.
        if not os.path.exists(arg.output_folder):
            os.makedirs(arg.output_folder)
        save_path = os.path.join(arg.output_folder, image)
        plt.savefig(save_path + '-result' + '.jpg')
    else:
        if arg.show_resultimg:
            plt.show()

    print("The Expression is %s" % expression)
Ejemplo n.º 29
0
def train(FLAG):
    """Train the 7-class segmentation VGG16 with early stopping.

    Loads train/val images and one-hot masks, minimizes ``vgg16.loss`` with
    Adam (or SGD+Nesterov when ``opt_type`` is 'sgd') under exponentially
    decaying learning rate, dumps predictions for three tracked validation
    images every 10 epochs, and checkpoints whenever validation loss improves
    by more than ``min_delta``; stops after ``early_stop_patience`` epochs
    without improvement.
    """
    print("Reading dataset...")
    # load data
    Xtrain, Ytrain = read_images(TRAIN_DIR), read_masks(TRAIN_DIR, onehot=True)
    Xtest, Ytest = read_images(VAL_DIR), read_masks(VAL_DIR, onehot=True)
    # Fixed validation images whose predictions are visualized periodically.
    track = [
        "hw3-train-validation/validation/0008",
        "hw3-train-validation/validation/0097",
        "hw3-train-validation/validation/0107"
    ]
    Xtrack, Ytrack = read_list(track)

    vgg16 = VGG16(classes=7, shape=(256, 256, 3))
    vgg16.build(vgg16_npy_path=FLAG.init_from,
                mode=FLAG.mode,
                keep_prob=FLAG.keep_prob)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    checkpoint_path = os.path.join(FLAG.save_dir, 'model.ckpt')

    def initialize_uninitialized(sess):
        # Initialize only variables the earlier global init did not cover
        # (e.g. optimizer slot variables created after it ran).
        global_vars = tf.global_variables()
        is_not_initialized = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [
            v for (v, f) in zip(global_vars, is_not_initialized) if not f
        ]
        if len(not_initialized_vars):
            sess.run(tf.variables_initializer(not_initialized_vars))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # hyper parameters
        batch_size = 32
        epoch = 500
        early_stop_patience = 50
        min_delta = 0.0001
        opt_type = 'adam'

        # recorder
        epoch_counter = 0

        # optimizer
        global_step = tf.Variable(0, trainable=False)

        # Passing global_step to minimize() will increment it at each step.
        # FIX: the original compared strings with 'is' (object identity);
        # compare by value instead.
        if opt_type == 'sgd':
            start_learning_rate = FLAG.lr
            half_cycle = 2000
            learning_rate = tf.train.exponential_decay(start_learning_rate,
                                                       global_step,
                                                       half_cycle,
                                                       0.5,
                                                       staircase=True)
            opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                             momentum=0.9,
                                             use_nesterov=True)
        else:
            start_learning_rate = FLAG.lr
            half_cycle = 2000
            learning_rate = tf.train.exponential_decay(start_learning_rate,
                                                       global_step,
                                                       half_cycle,
                                                       0.5,
                                                       staircase=True)
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)

        obj = vgg16.loss
        train_op = opt.minimize(obj, global_step=global_step)

        # progress bar
        ptrain = IntProgress()
        pval = IntProgress()
        display(ptrain)
        display(pval)
        ptrain.max = int(Xtrain.shape[0] / batch_size)
        pval.max = int(Xtest.shape[0] / batch_size)

        # re-initialize
        initialize_uninitialized(sess)

        # reset due to adding a new task
        patience_counter = 0
        # FIX: np.float was removed in NumPy 1.20+; the builtin float is
        # equivalent here and version-proof.
        current_best_val_loss = float('inf')

        # optimize when the aggregated obj
        while (patience_counter < early_stop_patience
               and epoch_counter < epoch):

            # start training
            stime = time.time()
            bar_train = Bar(
                'Training',
                max=int(Xtrain.shape[0] / batch_size),
                suffix='%(index)d/%(max)d - %(percent).1f%% - %(eta)ds')
            bar_val = Bar(
                'Validation',
                max=int(Xtest.shape[0] / batch_size),
                suffix='%(index)d/%(max)d - %(percent).1f%% - %(eta)ds')

            train_loss, train_accu = 0.0, 0.0
            for i in range(int(Xtrain.shape[0] / batch_size)):
                st = i * batch_size
                ed = (i + 1) * batch_size
                loss, accu, _ = sess.run(
                    [obj, vgg16.accuracy, train_op],
                    feed_dict={
                        vgg16.x: Xtrain[st:ed, :],
                        vgg16.y: Ytrain[st:ed, :],
                        vgg16.is_train: True
                    })
                train_loss += loss
                train_accu += accu
                ptrain.value += 1
                ptrain.description = "Training %s/%s" % (ptrain.value,
                                                         ptrain.max)
            train_loss = train_loss / ptrain.value
            train_accu = train_accu / ptrain.value

            # validation
            val_loss = 0
            val_accu = 0
            for i in range(int(Xtest.shape[0] / batch_size)):
                st = i * batch_size
                ed = (i + 1) * batch_size
                loss, accu = sess.run(
                    [obj, vgg16.accuracy],
                    feed_dict={
                        vgg16.x: Xtest[st:ed, :],
                        vgg16.y: Ytest[st:ed, :],
                        vgg16.is_train: False
                    })
                val_loss += loss
                val_accu += accu
                pval.value += 1
                # FIX: the original printed value/value instead of value/max.
                pval.description = "Testing %s/%s" % (pval.value, pval.max)
            val_loss = val_loss / pval.value
            val_accu = val_accu / pval.value

            # Every 10 epochs, save upscaled colorized predictions for the
            # tracked validation images.
            if epoch_counter % 10 == 0:
                Xplot = sess.run(vgg16.pred,
                                 feed_dict={
                                     vgg16.x: Xtrack[:, :],
                                     vgg16.y: Ytrack[:, :],
                                     vgg16.is_train: False
                                 })

                for i, fname in enumerate(track):
                    # order=0 (nearest) keeps the discrete label values intact.
                    saveimg = skimage.transform.resize(Xplot[i],
                                                       output_shape=(512, 512),
                                                       order=0,
                                                       preserve_range=True,
                                                       clip=False)
                    saveimg = label2rgb(saveimg)
                    imageio.imwrite(
                        os.path.join(
                            FLAG.save_dir,
                            os.path.basename(fname) + "_pred_" +
                            str(epoch_counter) + ".png"), saveimg)
                    print(
                        os.path.join(
                            FLAG.save_dir,
                            os.path.basename(fname) + "_pred_" +
                            str(epoch_counter) + ".png"))

            # early stopping check
            if (current_best_val_loss - val_loss) > min_delta:
                current_best_val_loss = val_loss
                patience_counter = 0
                saver.save(sess, checkpoint_path, global_step=epoch_counter)
                print("save in %s" % checkpoint_path)
            else:
                patience_counter += 1

            # shuffle Xtrain and Ytrain in the next epoch
            idx = np.random.permutation(Xtrain.shape[0])
            Xtrain, Ytrain = Xtrain[idx, :, :, :], Ytrain[idx, :]

            # epoch end
            epoch_counter += 1

            ptrain.value = 0
            pval.value = 0
            bar_train.finish()
            bar_val.finish()

            print(
                "Epoch %s (%s), %s sec >> train loss: %.4f, train accu: %.4f, val loss: %.4f, val accu: %.4f"
                % (epoch_counter, patience_counter,
                   round(time.time() - stime,
                         2), train_loss, train_accu, val_loss, val_accu))
 def _get_vgg_model(self, pretrained_checkpoint: Optional[str]=None) -> VGG16:
     """Build a 10-class VGG16, optionally restoring pretrained weights.

     Args:
         pretrained_checkpoint: path to a state-dict file, or None to keep
             the freshly initialized weights.

     Returns:
         The constructed (and possibly weight-loaded) VGG16 instance.
     """
     model = VGG16(10)
     if pretrained_checkpoint is None:
         return model
     print("[INFO] Loading pretrained weights. ({})".format(pretrained_checkpoint))
     model.load_state_dict(torch.load(pretrained_checkpoint))
     return model