def main():
    # Load best model
    checkpoint_path = './model_checkpoint.pth'
    model = MobileNet().to('cuda')
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])  # apply the saved weights
    model.eval()  # switch to inference mode
    print("Model loaded successfully.")

    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    image = Image.open('./test.jpg')
    image_tensor = transform(image)
    image_tensor = torch.unsqueeze(image_tensor, 0)  # add a batch dimension
    image_tensor = image_tensor.to('cuda')
    with torch.no_grad():  # no gradients needed at inference
        output = model(image_tensor)
    result = (output > 0.5).cpu().numpy()

    for t in range(len(attributes)):
        if result[0][t]:
            print("Attribute: \033[1;35m%s \033[0m, \033[1;35m%s \033[0m" %
                  (attributes[t], result[0][t]))
        else:
            print("Attribute: %s, %s" % (attributes[t], result[0][t]))
Example #2
def main():
    # Parse the JSON arguments
    try:
        config_args = parse_args()
    except Exception:
        print("Add a config file using '--config file_name.json'")
        exit(1)

    # Create the experiment directories
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
        config_args.experiment_dir)

    # Reset the default Tensorflow graph
    tf.reset_default_graph()

    # Tensorflow specific configuration
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Data loading
    data = DataLoader(config_args.batch_size, config_args.shuffle)
    print("Loading Data...")
    config_args.img_height, config_args.img_width, config_args.num_channels, \
    config_args.train_data_size, config_args.test_data_size = data.load_data()
    print("Data loaded\n\n")

    # Model creation
    print("Building the model...")
    model = MobileNet(config_args)
    print("Model is built successfully\n\n")

    # Summarizer creation
    summarizer = Summarizer(sess, config_args.summary_dir)
    # Train class

    trainer = Train(sess, model, data, summarizer)

    if config_args.to_train:
        try:
            print("Training...")
            start_time = time.time()
            trainer.train()
            print("Training Finished\n\n")
            print('Training time (s):', time.time() - start_time)
        except KeyboardInterrupt:
            trainer.save_model()

    if config_args.to_test:
        print("Final test!")
        trainer.test('val')
        print("Testing Finished\n\n")
def main():
    """
    Script entrypoint
    """
    t_start = datetime.now()
    header = ["Start Time", "End Time", "Duration (s)"]
    row = [t_start.strftime(DEFAULT_DATE_TIME_FORMAT)]

    dnn = MobileNet()

    # show class indices
    print('****************')
    for cls, idx in dnn.train_batches.class_indices.items():
        print('Class #{} = {}'.format(idx, cls))
    print('****************')

    dnn.model.summary()  # prints the layer summary (returns None)

    dnn.train(t_start,
              epochs=dnn.num_epochs,
              batch_size=dnn.batch_size,
              training=dnn.train_batches,
              validation=dnn.valid_batches)

    # save trained weights
    dnn.model.save(dnn.file_weights + 'old')

    dnn.model.save_weights(dnn.file_weights)
    with open(dnn.file_architecture, 'w') as f:
        f.write(dnn.model.to_json())

    t_end = datetime.now()
    difference_in_seconds = get_difference_in_seconds(t_start, t_end)

    row.append(t_end.strftime(DEFAULT_DATE_TIME_FORMAT))
    row.append(str(difference_in_seconds))

    append_row_to_csv(complete_run_timing_file, header)
    append_row_to_csv(complete_run_timing_file, row)
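
# get_difference_in_seconds and append_row_to_csv are assumed utilities;
# hypothetical minimal versions consistent with how they are called above:
import csv

def get_difference_in_seconds(t_start, t_end):
    return (t_end - t_start).total_seconds()

def append_row_to_csv(file_path, row):
    with open(file_path, 'a', newline='') as f:
        csv.writer(f).writerow(row)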
Example #4
def main():
    model = MobileNet(args.conv, args.fc)

    if args.mode == 'train':
        print('Training ...')
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop(44),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=10, scale=(0.9, 1.1)),
            transforms.ToTensor()
        ])
        valid_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop(44),
            transforms.ToTensor()
        ])
        train_set = TrainDataset(args.dataset,
                                 mode='train',
                                 transform=train_transform)
        valid_set = TrainDataset(args.dataset,
                                 mode='valid',
                                 transform=valid_transform)
        train_data = DataLoader(dataset=train_set,
                                batch_size=args.bs,
                                shuffle=True)
        valid_data = DataLoader(dataset=valid_set, batch_size=args.bs)

        manager = Manager(model, args)
        manager.train(train_data, valid_data)

    else:
        print('Predicting ...')
        test_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.TenCrop(44),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops]))
        ])
        test_set = TestDataset(args.dataset, transform=test_transform)
        test_data = DataLoader(dataset=test_set, batch_size=args.bs)

        manager = Manager(model, args)
        manager.predict(test_data)
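
# With transforms.TenCrop each test batch has shape (batch, ncrops, C, H, W);
# the usual idiom (shown in the torchvision docs) is to fold the crops into
# the batch dimension and average the outputs, which Manager.predict
# presumably does per batch:
#
#     bs, ncrops, c, h, w = images.size()
#     outputs = model(images.view(-1, c, h, w))        # fuse batch and crops
#     outputs = outputs.view(bs, ncrops, -1).mean(1)   # average over crops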
Example #5
def get_mobilenet_model(pretrained=True, num_classes=5, requires_grad=True):
	# Return a MobileNet backbone with the fully connected layer removed
	model = MobileNet()
	# Freeze or unfreeze the backbone layers
	for param in model.parameters():
		param.requires_grad = requires_grad

	if pretrained:
		# Load pre-trained weights for the base net; fine-tuning from them improves accuracy
		basenet_state = torch.load("/home/pzl/object-localization/pretained/mobienetv2.pth")
		# keep only the keys that exist in this model
		model_dict = model.state_dict()
		pretrained_dict = {k: v for k, v in basenet_state.items() if k in model_dict}
		# merge into the full state dict and load it
		model_dict.update(pretrained_dict)
		model.load_state_dict(model_dict)
	return model
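
# Usage sketch for the helper above: freeze the pretrained backbone for
# feature extraction by passing requires_grad=False (a hypothetical call,
# names follow the signature above):
#
#     backbone = get_mobilenet_model(pretrained=True, num_classes=5,
#                                    requires_grad=False)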
def load():
	tra_i = 0
	filenames = os.listdir('./data')
	print(filenames)
	for e in filenames:
		img = cv2.imread('./data/' + e)
		img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
		if e[0] == 'p':      # 'p...' files are class 0
			x_train[tra_i] = img
			y_train.append([0])
			tra_i += 1
		elif e[0] == 'a':    # 'a...' files are class 1
			x_train[tra_i] = img
			y_train.append([1])
			tra_i += 1
	return (x_train, np.array(y_train)), (x_test, np.array(y_test))
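
# x_train / y_train / x_test / y_test are module-level globals that load()
# fills in; a hypothetical preallocation consistent with the 224x224 RGB
# inputs used below (sizes are assumptions):
x_train = np.zeros((len(os.listdir('./data')), 224, 224, 3), dtype=np.float32)
y_train = []
x_test = np.zeros((0, 224, 224, 3), dtype=np.float32)
y_test = []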

(x_train, y_train), (x_test, y_test) = load()
print(x_train.shape)
print(y_train.shape)
y_train = keras.utils.to_categorical(y_train, num_classes)
x_train /= 255
print(x_train.shape)
print(y_train.shape)
img_input = keras.layers.Input(shape=(224, 224, 3))
model = MobileNet(input_tensor=img_input, classes=num_classes)
model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(lr=0.0004, decay=5e-4),
              metrics=['accuracy'])
model.fit(x_train, y_train, validation_split=0.2, batch_size=batch_size,
          epochs=epochs, verbose=1, shuffle=True)
Example #7
from utils import validation, cal_dis
from dataloader import VeRi_dataloader, VeRI_validation_dataloader
from losss import BatchHardTripletLoss

################## config ################
device = torch.device("cuda:7")
date = time.strftime("%m-%d", time.localtime())
#date = "03-14"

model_path = "/home/lxd/checkpoints/" + date

model_name = sys.argv[1]
if model_name == "vgg16":
    model = Vgg16Net()
elif model_name == "mobile":
    model = MobileNet()
elif model_name == "alexnet":
    model = AlexNet()
elif model_name == "res50":
    model = ResNet50()
elif model_name == "res34":
    model = ResNet34()
elif model_name == "vgg11":
    model = Vgg11Net()
else:
    sys.exit("Unknown model name: {}".format(model_name))
model.to(device)

# train/test
loss_name = sys.argv[2]
batch = sys.argv[3]
# trainer.detect(FaceCropper().generate('fake.png'))


if __name__ == '__main__':
    # main()
    config_args = parse_args()
    config_args.img_height, config_args.img_width, config_args.num_channels = (
        224, 224, 3)
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
        config_args.experiment_dir)
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    faces = FaceCropper().generate('maxresdefault.jpg')
    with tf.Session(config=config) as sess:
        config_args.batch_size = len(faces)
        model = MobileNet(config_args)
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        saver = tf.train.Saver(max_to_keep=config_args.max_to_keep,
                               keep_checkpoint_every_n_hours=10,
                               save_relative_paths=True)
        saver.restore(sess,
                      tf.train.latest_checkpoint(config_args.checkpoint_dir))

        # show camera
        face_cascade = cv2.CascadeClassifier(
            'haarcascade_frontalface_default.xml')
        cap = cv2.VideoCapture(0)
        while True:
            _, img = cap.read()
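            # The snippet is cut off here; a hedged sketch of a typical
            # Haar-cascade continuation (not the original author's code):
            #
            #     gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            #     for (x, y, w, h) in face_cascade.detectMultiScale(gray, 1.3, 5):
            #         cv2.rectangle(img, (x, y), (x + w, y + h), (255, 0, 0), 2)
            #     cv2.imshow('faces', img)
            #     if cv2.waitKey(1) & 0xFF == ord('q'):
            #         break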
Example #9
def main(_):
    conf = configure()
    model = MobileNet(conf)
    run(model)
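
# The unused underscore argument matches TF1's tf.app.run() convention;
# a typical launcher (an assumption, not shown in the original):
if __name__ == '__main__':
    tf.app.run()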
Example #10
from model import Vgg16Net, MobileNet, ResNet50, Vgg11Net, ResNet34, AlexNet
from utils import validation
from dataloader import VeRi_dataloader
from losss import BatchHardTripletLoss, contrastive_loss

################## config ################
date = time.strftime("%m-%d", time.localtime())
model_path = "/home/lxd/checkpoints/" + date

model_name = sys.argv[1]
if model_name == "vgg16":
    model = Vgg16Net()
elif model_name == "alexnet":
    model = AlexNet()
elif model_name == "mobile":
    model = MobileNet()
elif model_name == "res50":
    model = ResNet50()
elif model_name == "res34":
    model = ResNet34()
elif model_name == "vgg11":
    model = Vgg11Net()
else:
    sys.exit("Unknown model name: {}".format(model_name))


gpu = sys.argv[2]
device = torch.device("cuda:{}".format(gpu))
model.to(device)

dataloader = VeRi_dataloader()
Example #11
        convFiles[num] = (str(convFiles[num]) + f'.{extension}')

    return convFiles


# Sorts images and labels into arrays
imgfiles = imgsort(os.listdir(args.img_dir), extension='png')

with open(args.label_dir, 'r') as f:
    labels = f.read().split(',')
    labels[-1] = labels[-1].rstrip()  # drop the trailing newline
    labels = np.expand_dims(np.array(labels), axis=1)

print(imgfiles)
print(labels)
print(labels.shape)

print("Parsing image files...")
imgs = np.array([cv2.imread(f'{args.img_dir}{img}')
                 for img in imgfiles]) / 255.0
print(imgs.shape)

print("Loading Model...")
model = MobileNet(classes=16)
model.compile(optimizer='adam', loss='categorical_crossentropy')
model.fit(imgs, labels, epochs=EPOCHS, batch_size=BATCH_SIZE)

outputs = model.predict(imgs[:3])

print(f"outputs:\n{[letters(i) for i in outputs]}")
Example #12
class HyperTrain(Trainable):
    def _get_dataset(self, name):

        normalize = transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )

        if name == 'FashionMNIST':

            data_transforms = transforms.Compose([
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(), normalize
            ])
            dataset = torchvision.datasets.FashionMNIST(
                root="/home/kn15263s/data/FashionMNIST",
                transform=data_transforms)
            num_classes = 10
            input_size = 512 * 1 * 1

            return dataset, num_classes, input_size

        elif name == 'KMNIST':

            data_transforms = transforms.Compose([
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(), normalize
            ])

            dataset = torchvision.datasets.KMNIST(
                root="/home/kn15263s/data/KMNIST",
                transform=data_transforms,
                download=True)
            num_classes = 10
            input_size = 512 * 1 * 1

            return dataset, num_classes, input_size

        elif name == 'CIFAR10':

            data_transforms = transforms.Compose(
                [transforms.ToTensor(), normalize])
            dataset = torchvision.datasets.CIFAR10(
                root="/home/kn15263s/data/CIFAR10/", transform=data_transforms)
            num_classes = 10
            input_size = 512 * 1 * 1

            return dataset, num_classes, input_size

        elif name == 'SVHN':

            data_transforms = transforms.Compose(
                [transforms.ToTensor(), normalize])
            dataset = torchvision.datasets.SVHN(
                root="/home/kn15263s/data/SVHN/", transform=data_transforms)
            num_classes = 10
            input_size = 512 * 1 * 1

            return dataset, num_classes, input_size

        elif name == 'STL10':

            data_transforms = transforms.Compose(
                [transforms.ToTensor(), normalize])
            dataset = torchvision.datasets.STL10(
                root="/home/kn15263s/data/STL10/", transform=data_transforms)
            num_classes = 10
            input_size = 512 * 3 * 3

            return dataset, num_classes, input_size

        # elif name == 'Food':
        #
        #     class Food(Dataset):
        #
        #         def __init__(self, files, class_names, transform=transforms.ToTensor()):
        #
        #             self.data = files
        #             self.transform = transform
        #             self.class_names = class_names
        #
        #         def __getitem__(self, idx):
        #             img = Image.open(self.data[idx]).convert('RGB')
        #             name = self.data[idx].split('/')[-2]
        #             y = self.class_names.index(name)
        #             img = self.transform(img)
        #             return img, y
        #
        #         def __len__(self):
        #             return len(self.data)
        #
        #     data_transforms = transforms.Compose([
        #         transforms.RandomHorizontalFlip(),
        #         transforms.RandomVerticalFlip(),
        #         transforms.Resize((224, 224)),
        #         transforms.ToTensor(),
        #         normalize])
        #
        #     path = '/home/willy-huang/workspace/data/food'
        #     files_training = glob(os.path.join(path, '*/*.jpg'))
        #     class_names = []
        #
        #     for folder in os.listdir(os.path.join(path)):
        #         class_names.append(folder)
        #
        #     num_classes = len(class_names)
        #     dataset = Food(files_training, class_names, data_transforms)
        #     input_size = 512 * 7 * 7
        #
        #     return dataset, num_classes, input_size
        #
        # elif name == 'Stanford_dogs':
        #
        #     class Stanford_dogs(Dataset):
        #
        #         def __init__(self, files, class_names, transform=transforms.ToTensor()):
        #
        #             self.data = files
        #             self.transform = transform
        #             self.class_names = class_names
        #
        #         def __getitem__(self, idx):
        #             img = Image.open(self.data[idx]).convert('RGB')
        #             name = self.data[idx].split('/')[-2]
        #             y = self.class_names.index(name)
        #             img = self.transform(img)
        #             return img, y
        #
        #         def __len__(self):
        #             return len(self.data)
        #
        #
        #     data_transforms = transforms.Compose([
        #         transforms.RandomHorizontalFlip(),
        #         transforms.RandomVerticalFlip(),
        #         transforms.Resize((224, 224)),
        #         transforms.ToTensor(),
        #         normalize])
        #
        #     path = '/home/willy-huang/workspace/data/stanford_dogs'
        #     files_training = glob(os.path.join(path, '*/*.jpg'))
        #     class_names = []
        #
        #     for folder in os.listdir(os.path.join(path)):
        #         class_names.append(folder)
        #
        #     num_classes = len(class_names)
        #     dataset = Stanford_dogs(files_training, class_names, data_transforms)
        #     input_size = 512 * 7 * 7
        #
        #     return dataset, num_classes, input_size

    def _setup(self, config):
        random.seed(50)
        np.random.seed(50)
        torch.cuda.manual_seed_all(50)
        torch.manual_seed(50)
        self.total_time = time.time()
        self.name = args.Dataset_name
        nnArchitecture = args.Network_name

        dataset, num_class, input_size = self._get_dataset(self.name)

        num_total = len(dataset)
        shuffle = np.random.permutation(num_total)
        split_val = int(num_total * 0.2)

        train_idx, valid_idx = shuffle[split_val:], shuffle[:split_val]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        self.trainset_ld = DataLoader(dataset,
                                      batch_size=256,
                                      sampler=train_sampler,
                                      num_workers=4)
        self.validset_ld = DataLoader(dataset,
                                      batch_size=256,
                                      sampler=valid_sampler,
                                      num_workers=4)

        self.modelname = '{}--{}.pth.tar'.format(self.name, nnArchitecture)
        loggername = self.modelname.replace("pth.tar", "log")
        self.logger = utils.buildLogger(loggername)

        self.seed_table = np.array([
            "", "epoch", "lr", "momentum", "weight_decay", "factor", "outLoss",
            "accuracy"
        ])

        # ---- hyperparameters ----
        self.lr = config["lr"]
        self.momentum = config["momentum"]
        self.weight_decay = config["weight_decay"]
        self.factor = config["factor"]

        self.epochID = 0
        self.loss = nn.CrossEntropyLoss()
        self.accuracy = float('-inf')  # best validation accuracy so far

        # -------------------- SETTINGS: NETWORK ARCHITECTURE

        if nnArchitecture == 'Vgg11':
            self.model = Vgg11(num_class, input_size).cuda()

        elif nnArchitecture == 'Resnet18':
            self.model = Resnet18(num_class, input_size).cuda()

        elif nnArchitecture == 'MobileNet':
            self.model = MobileNet(num_class, input_size).cuda()

        elif nnArchitecture == 'MobileNet_V2':
            self.model = MobileNet_V2(num_class, input_size).cuda()

        else:
            raise ValueError('Unknown network architecture: ' + nnArchitecture)

        self.model = torch.nn.DataParallel(self.model).cuda()
        self.logger.info("Build Model Done")

        # -------------------- SETTINGS: OPTIMIZER & SCHEDULER --------------------
        self.optimizer = optim.SGD(filter(lambda x: x.requires_grad,
                                          self.model.parameters()),
                                   lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay,
                                   nesterov=False)

        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, factor=self.factor, patience=10, mode='min')

        self.logger.info("Build Optimizer Done")

    def _train_iteration(self):
        self.start_time = time.time()
        self.model.train()

        losstra = 0
        losstraNorm = 0

        for batchID, (input, target) in enumerate(self.trainset_ld):
            varInput = Variable(input).cuda()
            varTarget = Variable(target).cuda()
            varOutput = self.model(varInput)

            lossvalue = self.loss(varOutput, varTarget)

            losstra += lossvalue.item()
            losstraNorm += 1

            self.optimizer.zero_grad()
            lossvalue.backward()
            torch.nn.utils.clip_grad_value_(self.model.parameters(), 10)
            self.optimizer.step()

        self.trainLoss = losstra / losstraNorm

    def _test(self):

        self.model.eval()

        lossVal = 0
        lossValNorm = 0
        correct = 0

        num_samples = 0
        for batchID, (input, target) in enumerate(self.validset_ld):
            with torch.no_grad():
                varInput = Variable(input).cuda(non_blocking=True)
                varTarget = Variable(target).cuda(non_blocking=True)
                varOutput = self.model(varInput)

                losstensor = self.loss(varOutput, varTarget)

                pred = varOutput.argmax(1)
                correct += (pred == varTarget).sum().cpu()

                lossVal += losstensor.item()
                lossValNorm += 1
                num_samples += len(input)

        self.outLoss = lossVal / lossValNorm
        accuracy = correct.item() / num_samples

        self.scheduler.step(self.outLoss, epoch=self.epochID)

        if accuracy > self.accuracy:
            self.accuracy = accuracy

            torch.save(
                {
                    'epoch': self.epochID + 1,
                    'state_dict': self.model.state_dict(),
                    'loss': self.outLoss,
                    'best_accuracy': self.accuracy,
                    'optimizer': self.optimizer.state_dict(),
                }, "./best_" + self.modelname)

            save = np.array([
                self.seed_table,
                [
                    str(self.name),
                    str(self.epochID + 1),
                    str(self.lr),
                    str(self.momentum),
                    str(self.weight_decay),
                    str(self.factor),
                    str(self.outLoss),
                    str(self.accuracy)
                ]
            ])

            np.savetxt("./seed(50).csv", save, delimiter=',', fmt="%s")

        self.logger.info('Epoch [' + str(self.epochID + 1) +
                         '] loss= {:.5f}'.format(self.outLoss) +
                         ' ---- accuracy= {:.5f}'.format(accuracy) +
                         ' ---- best_accuracy= {:.5f}'.format(self.accuracy) +
                         ' ---- model: {}'.format(self.modelname) +
                         ' ---- time: {:.1f} s'.format((time.time() -
                                                        self.start_time)) +
                         ' ---- total_time: {:.1f} s'.format(
                             (time.time() - self.total_time)))

        self.epochID += 1
        return {
            "episode_reward_mean": accuracy,
            "neg_mean_loss": self.outLoss,
            "mean_accuracy": accuracy,
            "epoch": self.epochID,
            'mean_train_loss': self.trainLoss
        }

    def _train(self):
        self._train_iteration()
        return self._test()

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "final_model.pth")
        torch.save(
            {
                "epoch": self.epochID,
                "best_accuracy": self.accuracy,
                'loss': self.outLoss,
                "state_dict": self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }, checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint["state_dict"])
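
# HyperTrain follows ray.tune's class-based Trainable API (_setup / _train /
# _save / _restore); a hypothetical launch sketch:
#
#     from ray import tune
#     tune.run(HyperTrain, config={"lr": 0.1, "momentum": 0.9,
#                                  "weight_decay": 5e-4, "factor": 0.5})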
Example #13
    def _setup(self, config):
        random.seed(50)
        np.random.seed(50)
        torch.cuda.manual_seed_all(50)
        torch.manual_seed(50)
        self.total_time = time.time()
        self.name = args.Dataset_name
        nnArchitecture = args.Network_name

        dataset, num_class, input_size = self._get_dataset(self.name)

        num_total = len(dataset)
        shuffle = np.random.permutation(num_total)
        split_val = int(num_total * 0.2)

        train_idx, valid_idx = shuffle[split_val:], shuffle[:split_val]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        self.trainset_ld = DataLoader(dataset,
                                      batch_size=256,
                                      sampler=train_sampler,
                                      num_workers=4)
        self.validset_ld = DataLoader(dataset,
                                      batch_size=256,
                                      sampler=valid_sampler,
                                      num_workers=4)

        self.modelname = '{}--{}.pth.tar'.format(self.name, nnArchitecture)
        loggername = self.modelname.replace("pth.tar", "log")
        self.logger = utils.buildLogger(loggername)

        self.seed_table = np.array([
            "", "epoch", "lr", "momentum", "weight_decay", "factor", "outLoss",
            "accuracy"
        ])

        # ---- hyperparameters ----
        self.lr = config["lr"]
        self.momentum = config["momentum"]
        self.weight_decay = config["weight_decay"]
        self.factor = config["factor"]

        self.epochID = 0
        self.loss = nn.CrossEntropyLoss()
        self.accuracy = float('-inf')  # best validation accuracy so far

        # -------------------- SETTINGS: NETWORK ARCHITECTURE

        if nnArchitecture == 'Vgg11':
            self.model = Vgg11(num_class, input_size).cuda()

        elif nnArchitecture == 'Resnet18':
            self.model = Resnet18(num_class, input_size).cuda()

        elif nnArchitecture == 'MobileNet':
            self.model = MobileNet(num_class, input_size).cuda()

        elif nnArchitecture == 'MobileNet_V2':
            self.model = MobileNet_V2(num_class, input_size).cuda()

        else:
            raise ValueError('Unknown network architecture: ' + nnArchitecture)

        self.model = torch.nn.DataParallel(self.model).cuda()
        self.logger.info("Build Model Done")

        # -------------------- SETTINGS: OPTIMIZER & SCHEDULER --------------------
        self.optimizer = optim.SGD(filter(lambda x: x.requires_grad,
                                          self.model.parameters()),
                                   lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay,
                                   nesterov=False)

        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, factor=self.factor, patience=10, mode='min')

        self.logger.info("Build Optimizer Done")
Example #14
def main():
    # Parse the JSON arguments
    try:
        config_args = parse_args()
    except Exception:
        print("Add a config file using '--config file_name.json'")
        exit(1)

    # Create the experiment directories
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
        config_args.experiment_dir)

    # Reset the default Tensorflow graph
    tf.reset_default_graph()

    # Tensorflow specific configuration
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Data loading
    data = DataLoader(config_args.batch_size, config_args.shuffle)
    print("Loading Data...")
    config_args.img_height, config_args.img_width, config_args.num_channels, \
    config_args.train_data_size, config_args.test_data_size = data.load_data()
    print("Data loaded\n\n")

    # Model creation
    print("Building the model...")
    if config_args.quantize:
        print('Quantized model created')
        # Quantized model creation
        activation_quantizer = linear_mid_tread_half_quantizer
        activation_quantizer_kwargs = {'bit': 2, 'max_value': 2}
        weight_quantizer = binary_mean_scaling_quantizer
        weight_quantizer_kwargs = {}
        model = MobileNetQuantize(
            config_args,
            activation_quantizer=activation_quantizer,
            activation_quantizer_kwargs=activation_quantizer_kwargs,
            weight_quantizer=weight_quantizer,
            weight_quantizer_kwargs=weight_quantizer_kwargs)
    else:
        print('Full precision model created')
        model = MobileNet(config_args)
    print("Model is built successfully\n\n")

    # Summarizer creation
    summarizer = Summarizer(sess, config_args.summary_dir)
    # Train class
    trainer = Train(sess, model, data, summarizer)

    if config_args.to_train:
        try:
            print("Training...")
            trainer.train()
            print("Training Finished\n\n")
        except KeyboardInterrupt:
            trainer.save_model()

    if config_args.to_test:
        print("Final test!")
        trainer.test('val')
        print("Testing Finished\n\n")
def main():
	# lists to store the losses and accuracy for plotting
	train_all_losses2 = []
	train_all_acc2 = []
	val_all_losses2 = []
	val_all_acc2 = []
	test_all_losses2 = 0.0
	# number of training epochs
	epochs = 100

	# instantiate Net class
	mobilenet = MobileNet()
	# use cuda to train the network
	mobilenet.to('cuda')
	# loss function and optimizer
	criterion = nn.BCELoss()
	learning_rate = 1e-3
	optimizer = torch.optim.Adam(mobilenet.parameters(), lr=learning_rate, betas=(0.9, 0.999))

	best_acc = 0.0

	for epoch in range(epochs):
	    train(mobilenet, epoch, train_all_losses2, train_all_acc2)
	    acc = validation(mobilenet, val_all_losses2, val_all_acc2, best_acc)
	    # record the best model
	    if acc > best_acc:
	      checkpoint_path = './model_checkpoint.pth'
	      best_acc = acc
	      # save the model and optimizer
	      torch.save({'model_state_dict': mobilenet.state_dict(),
	              'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path)
	      print('new best model saved')
	    print("========================================================================")

	checkpoint_path = './model_checkpoint.pth'
	model = MobileNet().to('cuda')
	checkpoint = torch.load(checkpoint_path)
	print("model load successfully.")

	model.load_state_dict(checkpoint['model_state_dict'])
	optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
	model.eval()
	attr_acc = []
	test(model, attr_acc=attr_acc)

	# plot results
	plt.figure(figsize=(8, 10))
	plt.barh(range(40), [100 * acc for acc in attr_acc], tick_label=attributes, fc='brown')
	plt.show()

	plt.figure(figsize=(8, 6))
	plt.xlabel('Epochs')
	plt.ylabel('Loss')
	plt.title('Loss')
	plt.grid(True, linestyle='-.')
	plt.plot(train_all_losses2, c='salmon', label='Training Loss')
	plt.plot(val_all_losses2, c='brown', label='Validation Loss')
	plt.legend(fontsize='12', loc='upper right')
	plt.show()

	plt.figure(figsize=(8, 6))
	plt.xlabel('Epochs')
	plt.ylabel('Accuracy')
	plt.title('Accuracy')
	plt.grid(True, linestyle='-.')
	plt.plot(train_all_acc2, c='salmon', label='Training Accuracy')
	plt.plot(val_all_acc2, c='brown', label='Validation Accuracy')
	plt.legend(fontsize='12', loc='lower right')
	plt.show()
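
	# Note: nn.BCELoss expects sigmoid probabilities, so this MobileNet must end
	# in a sigmoid; the train/validation/test helpers called above are assumed
	# to be defined elsewhere with the signatures shown.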