Пример #1
0
    def __init__(self, options):
        """Prepare the dataset, the data loader and a pre-trained encoder.

        Reads ``device``, ``enc_path``, ``content_size``, ``enc_iter``,
        ``crop_size``, ``image_size``, ``data_root``, ``batch_size``,
        ``nloader`` and the ``enc_*`` structure fields from ``options``.

        Side effects: creates ``<enc_path>/codes`` when it does not exist and
        loads the encoder weights from
        ``<enc_path>/models/<enc_iter>_enc.pt``.
        """
        self.device = options.device
        self.enc_path = options.enc_path
        self.content_size = options.content_size
        self.enc_iter = options.enc_iter

        # exist_ok=True replaces the racy "check, then create" pair: another
        # process creating the folder in between no longer raises.
        os.makedirs(os.path.join(self.enc_path, 'codes'), exist_ok=True)

        transforms = []
        if options.crop_size is not None:
            transforms.append(T.CenterCrop(options.crop_size))
        transforms.append(T.Resize(options.image_size))
        transforms.append(T.ToTensor())
        # Four-entry mean/std: the first three channels are mapped with
        # (x - 0.5) / 0.5, the fourth (mean 0, std 1) is left unchanged --
        # presumably an alpha channel, TODO confirm against the dataset.
        transforms.append(T.Normalize((0.5, 0.5, 0.5, 0), (0.5, 0.5, 0.5, 1)))

        self.dataset = ImageFolder(options.data_root,
                                   transform=T.Compose(transforms))
        # No shuffle argument is passed, so the loader keeps its default order.
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=options.batch_size,
            num_workers=options.nloader)
        self.data_iter = iter(self.dataloader)

        self.enc = models.Encoder(options.image_size, options.image_size,
                                  options.enc_features, options.enc_blocks,
                                  options.enc_adain_features,
                                  options.enc_adain_blocks,
                                  options.content_size)
        self.enc.to(self.device)
        # map_location lets weights saved on another device load onto
        # self.device.
        self.enc.load_state_dict(
            torch.load(os.path.join(self.enc_path, 'models',
                                    '{0}_enc.pt'.format(self.enc_iter)),
                       map_location=self.device))
Пример #2
0
def main(args):
    """Evaluate a ConvNet on the test split.

    Picks CUDA when available (stored on ``args.device``), optionally
    restores weights from ``args.model_weight`` and hands the test loader,
    model and loss to ``validate``.
    """
    use_cuda = torch.cuda.is_available()
    args.device = torch.device('cuda' if use_cuda else 'cpu')
    print('device: {}'.format(args.device))

    # Model and loss live on the selected device.
    net = ConvNet(cfg.NUM_CLASSES).to(args.device)
    loss_fn = nn.CrossEntropyLoss().to(args.device)

    # Restore weights when a checkpoint path was given; a missing file is
    # only reported, evaluation still runs.
    weight_file = args.model_weight
    if weight_file and not os.path.isfile(weight_file):
        print("=> no checkpoint found at '{}'".format(weight_file))
    elif weight_file:
        print("=> loading checkpoint '{}'".format(weight_file))

        saved = torch.load(weight_file, map_location=args.device)
        net.load_state_dict(saved['state_dict'])

        print("=> loaded checkpoint '{}'".format(weight_file))

    # Test data, iterated in a fixed order.
    loader = torch.utils.data.DataLoader(
        ImageFolder(cfg.TEST_PATH, mode='test'),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    validate(loader, net, loss_fn, args)
    def __init__(self, options):
        """Create the dataset, loader, classifier and its Adam optimizer.

        The classifier is registered under the name 'cla'; when
        ``self.load_path`` is set, the state at ``options.load_iter`` is
        restored afterwards.
        """
        super(ClassifierTrainer, self).__init__(options, copy_keys=copy_keys)

        size = options.image_size

        # Optional initial crop, then resize + centre crop to the target size.
        pipeline = []
        if options.crop_size is not None:
            pipeline.append(T.CenterCrop(options.crop_size))
        pipeline.extend([
            T.Resize(size),
            T.CenterCrop(size),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5, 0), (0.5, 0.5, 0.5, 1)),
        ])

        self.dataset = ImageFolder(options.data_root,
                                   transform=T.Compose(pipeline))
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=options.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=options.nloader)
        self.data_iter = iter(self.dataloader)

        # Classifier network with one output per class.
        self.cla = models.ClassifierOrDiscriminator(
            size, size,
            options.cla_features, options.cla_blocks,
            options.cla_adain_features, options.cla_adain_blocks,
            self.nclass)
        self.cla.to(self.device)
        self.cla_optim = optim.Adam(self.cla.parameters(), lr=self.lr, eps=1e-4)
        self.add_model('cla', self.cla, self.cla_optim)

        resuming = self.load_path is not None
        if resuming:
            self.load(options.load_iter)
Пример #4
0
def detect_img(img_path,
               model,
               device,
               detect_size,
               class_names,
               conf_thres,
               nms_thres,
               output_dir,
               save_plot=True):
    """
    Detect veins on a single image with a trained model.

    Args:
        img_path: path to the image (the original notes ask for an absolute
            path).
        model: yolo loaded with checkpoint weights.
        device: torch.device the input tensor is moved to before inference.
        detect_size: input image size of the model.
        class_names: a list of target class names.
        conf_thres: confidence threshold passed to non_max_suppression.
        nms_thres: NMS threshold passed to non_max_suppression.
        output_dir: passed through to plot_img_and_bbox.
        save_plot: passed through to plot_img_and_bbox.

    Returns:
        A tuple ``(detections_list, detections_rescaled, inference_time)``:
        the per-sample list returned by non_max_suppression, the detections
        of the single sample rescaled to the original image size (``None``
        when nothing was detected), and the elapsed wall-clock seconds of the
        forward pass plus NMS.
    """
    print(f'\nPerforming object detection on ---> {img_path} \t')

    # Only the original (height, width) is needed for rescaling, so read it
    # from the image header instead of decoding the whole image into an array.
    with Image.open(img_path) as original:
        original_hw = (original.height, original.width)

    img = ImageFolder(folder_path='',
                      img_size=detect_size).preprocess(img_path)
    # Move to the requested device as float32. The legacy
    # torch.cuda.FloatTensor cast always targeted the *current* CUDA device
    # and ignored the index carried by `device`.
    img = img.unsqueeze(0).to(device=device, dtype=torch.float32)

    # NOTE(review): CUDA kernels run asynchronously, so on GPU this wall-clock
    # measurement is only accurate if the NMS step forces a synchronization.
    begin_time = time.time()
    with torch.no_grad():
        outputs = model(img)
        detections_list = non_max_suppression(outputs, conf_thres, nms_thres)
    end_time = time.time()

    inference_time = end_time - begin_time
    print(f'inference_time: {inference_time}s')

    detections_rescaled = None
    if detections_list[0] is not None:
        # it is a list due to the implementation of non_max_suppression that
        # deals with batch samples; only one image was passed in here
        detections_rescaled = rescale_boxes(detections_list[0].clone(),
                                            detect_size, original_hw)
        plot_img_and_bbox(img_path, detections_rescaled, detect_size,
                          class_names, output_dir, save_plot)
    else:
        print('No veins detected on this image!')

    return detections_list, detections_rescaled, inference_time
Пример #5
0
def process_options(options):
	"""Validate, complete and persist the parsed command-line options.

	The function mutates ``options`` in place and returns it.

	* For 'encode'/'stage2' an encoder path is required; for the training
	  commands ('stage1', 'classifier', 'stage2', 'gan') ``save_path`` falls
	  back to ``load_path`` and is required.
	* ``augment`` and ``cla_fake`` are converted from the string 'true' to
	  booleans when given.
	* When ``load_path`` is set, options are restored from the JSON file
	  ``<load_path>/options``: every key in ``load_keys`` is overwritten,
	  keys in ``override_keys`` are only filled in when still ``None``.
	* Otherwise missing values are derived from the stage-1 encoder options,
	  the pre-trained classifier options, ``stage_defaults``, ``defaults``,
	  the dataset and the ``default_sizes`` table.
	* For the training commands the resulting options are written to
	  ``<save_path>/options`` and ``options.augment_options`` is built.

	Raises:
		ValueError: if a required path is missing, image sizes do not match,
			the image size is odd, or a default could not be determined.
	"""
	if options.cmd in ['encode', 'stage2']:
		if options.enc_path is None:
			raise ValueError('encoder path must be specified')
	if options.cmd in ['stage1', 'classifier', 'stage2', 'gan']:
		# resuming without an explicit save path keeps saving to the load path
		options.save_path = options.save_path or options.load_path
		if options.save_path is None:
			raise ValueError('save path must be specified')

	# string flags -> booleans (anything other than 'true' becomes False)
	if options.augment is not None:
		options.augment = options.augment == 'true'
	if options.cla_fake is not None:
		options.cla_fake = options.cla_fake == 'true'

	if options.load_path is not None:
		# resume: take the structure from the saved options file
		with open(os.path.join(options.load_path, 'options')) as file:
			saved_options = json.load(file)
		for key in load_keys:
			options.__dict__[key] = saved_options[key]
		for key in override_keys:
			if options.__dict__[key] is None:
				options.__dict__[key] = saved_options[key]
		options.load_iter = options.load_iter or 'last'
	else:
		if options.cmd in ['encode', 'stage2']:
			# reuse the encoder structure and data settings saved by stage 1
			with open(os.path.join(options.enc_path, 'options')) as file:
				enc_options = json.load(file)

			print('using encoder structure from stage 1')
			for key in enc_keys + ['data_root', 'image_size', 'content_size']:
				options.__dict__[key] = enc_options[key]

			if options.cmd == 'stage2':
				if not options.reset_gen:
					print('using generator structure from stage 1')
					for key in gen_keys:
						options.__dict__[key] = enc_options[key]

				# the style size is only free when both generator and style bank are reset
				if not (options.reset_gen and options.reset_sty):
					print('using length of style code from stage 1')
					options.style_size = enc_options['style_size']

		if options.cmd == 'stage2' and options.cla_path is not None:
			with open(os.path.join(options.cla_path, 'options')) as file:
				cla_options = json.load(file)

			if cla_options['image_size'] != options.image_size:
				raise ValueError('image size of stage 1 networks and classifier 2 does not match')

			if not options.reset_dis:
				# discriminator keys are filled from the classifier keys, paired by position
				print('using discriminator structure from pre-trained classifier')
				for key1, key2 in zip(dis_keys, cla_keys):
					options.__dict__[key1] = cla_options[key2]

			if not options.reset_cla:
				print('using classifier structure from pre-trained classifier')
				for key in cla_keys:
					options.__dict__[key] = cla_options[key]

		# per-command defaults take precedence over the global defaults
		for key, value in stage_defaults[options.cmd].items():
			if options.__dict__[key] is None:
				options.__dict__[key] = value
		for key, value in defaults.items():
			if options.__dict__[key] is None:
				options.__dict__[key] = value

		dataset = ImageFolder(options.data_root, transform = T.ToTensor())
		if options.image_size is None:
			if options.crop_size is None:
				print('image size not specified, using image size of dataset')
				# dimension 1 of the first image tensor
				options.image_size = dataset[0][0].size(1)
			else:
				print('image size not specified, using crop size')
				options.image_size = options.crop_size

		options.nclass = dataset.get_nclass()
		if options.style_size is None:
			print('style size not specified, using defaults')
			# nclass // 4 clamped to [16, 256], but never more than nclass
			options.style_size = min(max(min(options.nclass // 4, 256), 16), options.nclass)

		if options.image_size % 2 != 0:
			raise ValueError('image size must be an even integer')

		options.vis_row = options.visualize_size[0]
		options.vis_col = options.visualize_size[1]

		# fill size-dependent defaults from every table row whose range contains the image size
		for min_size, max_size, features, content_size, batch_size, vis_col in default_sizes:
			if min_size <= options.image_size <= max_size:
				options.features = options.features or features
				options.mlp_features = options.mlp_features or features[-1]
				if options.blocks is None:
					options.blocks = (['cc'] + ['cbc'] * (len(options.features) - 2) + ['f'])
					options.gen_blocks = options.gen_blocks or (['tc'] + ['cbc'] * (len(options.features) - 2) + ['f'])
				options.content_size = options.content_size or content_size
				options.batch_size = options.batch_size or (batch_size * 2 if options.cmd in ['classifier', 'gan', 'encode'] else batch_size)
				options.vis_col = options.vis_col or (vis_col * 2 if options.cmd in ['gan'] else vis_col)
		if options.content_size is None:
			raise ValueError('content size not specified and failed to set defaults')
		if options.batch_size is None:
			raise ValueError('batch size not specified and failed to set defaults')
		options.vis_col = options.vis_col or (10 if options.image_size < 16 else 2)
		options.vis_row = options.vis_row or (options.vis_col if options.cmd in ['gan'] else options.vis_col * 2)

		options.content_dropout = options.content_dropout or (1 - 2 / options.content_size)
		options.style_dropout = options.style_dropout or (1 - 2 / options.style_size)

		# each network's structure keys fall back to the generic net_keys, paired by position
		if options.cmd == 'stage1':
			for key1, key2 in zip(enc_keys, net_keys):
				if options.__dict__[key1] is None:
					options.__dict__[key1] = options.__dict__[key2]
				if options.__dict__[key1] is None:
					raise ValueError('encoder structure incomplete and failed to set defaults')

		if options.cmd in ['stage1', 'stage2', 'gan']:
			for key1, key2 in zip(gen_keys, net_keys):
				if options.__dict__[key1] is None:
					options.__dict__[key1] = options.__dict__[key2]
				if options.__dict__[key1] is None:
					raise ValueError('generator structure incomplete and failed to set defaults')

		if options.cmd == 'classifier' or (options.cmd == 'stage1' and not options.mlp):
			for key1, key2 in zip(cla_keys, net_keys):
				if options.__dict__[key1] is None:
					options.__dict__[key1] = options.__dict__[key2]
				if options.__dict__[key1] is None:
					raise ValueError('classifier structure incomplete and failed to set defaults')

		if options.cmd in ['stage2', 'gan']:
			for key1, key2 in zip(dis_keys, net_keys):
				if options.__dict__[key1] is None:
					options.__dict__[key1] = options.__dict__[key2]
				if options.__dict__[key1] is None:
					raise ValueError('discriminator structure incomplete and failed to set defaults')

	if options.cmd in ['stage1', 'classifier', 'stage2', 'gan']:
		# persist the resolved options so a later run can resume from save_path
		save_options = {}
		for key in load_keys + override_keys:
			save_options[key] = options.__dict__[key]
		if not os.path.exists(options.save_path):
			os.makedirs(options.save_path)
		with open(os.path.join(options.save_path, 'options'), 'w') as file:
			json.dump(save_options, file)

		# bundle the augmentation-related options into their own namespace
		options.augment_options = types.SimpleNamespace()
		for key in augment_keys:
			options.augment_options.__dict__[key] = options.__dict__[key]

	return options
Пример #6
0
    def __init__(self, options):
        """Create the data pipeline, generator, discriminator and fixed codes.

        Generator and discriminator are trained with RMSprop and registered
        as 'gen' and 'dis'. The codes used for the fixed visualization are
        loaded from ``<load_path>/samples/codes.pt`` when resuming, otherwise
        drawn with ``gaussian_noise``; they are written to the save path when
        it differs from the load path.
        """
        super(GANTrainer, self).__init__(options,
                                         subfolders=['samples'],
                                         copy_keys=copy_keys)

        size = options.image_size

        # Optional centre crop, then resize, tensor conversion, normalization.
        pipeline = []
        if options.crop_size is not None:
            pipeline.append(T.CenterCrop(options.crop_size))
        pipeline.extend([
            T.Resize(size),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5, 0), (0.5, 0.5, 0.5, 1)),
        ])

        self.dataset = ImageFolder(options.data_root,
                                   transform=T.Compose(pipeline))
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=options.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=options.nloader)
        self.data_iter = iter(self.dataloader)

        # Generator.
        self.gen = models.Generator(
            size, size,
            options.gen_features, options.gen_blocks,
            options.gen_adain_features, options.gen_adain_blocks,
            options.content_size)
        self.gen.to(self.device)
        self.gen_optim = optim.RMSprop(
            self.gen.parameters(), lr=self.lr, eps=1e-4, alpha=0.9)
        self.add_model('gen', self.gen, self.gen_optim)

        # Discriminator (no class-count argument is passed here).
        self.dis = models.ClassifierOrDiscriminator(
            size, size,
            options.dis_features, options.dis_blocks,
            options.dis_adain_features, options.dis_adain_blocks)
        self.dis.to(self.device)
        self.dis_optim = optim.RMSprop(
            self.dis.parameters(), lr=self.lr, eps=1e-4, alpha=0.9)
        self.add_model('dis', self.dis, self.dis_optim)

        # Fixed codes for visualization: reuse the saved ones when resuming.
        resuming = self.load_path is not None
        if resuming:
            codes_file = os.path.join(self.load_path, 'samples', 'codes.pt')
            self.vis_codes = torch.load(codes_file, map_location=self.device)
            self.load(options.load_iter)
        else:
            code_count = options.vis_row * options.vis_col
            fresh_codes = gaussian_noise(code_count, options.content_size)
            self.vis_codes = fresh_codes.to(self.device)
            self.state.dis_total_batches = 0

        if self.save_path != self.load_path:
            target_file = os.path.join(self.save_path, 'samples', 'codes.pt')
            torch.save(self.vis_codes, target_file)

        self.add_periodic_func(self.visualize_fixed, options.visualize_iter)
        self.visualize_fixed()

        self.loss_avg_factor = 0.9
Пример #7
0
    'crop_size': 380
}

# Random crop (size taken from args['crop_size']), horizontal flip and
# rotation of up to 10; presumably applied to image and target together --
# TODO confirm in joint_transforms.
joint_transform = joint_transforms.Compose([
    joint_transforms.RandomCrop(args['crop_size']),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10),
])

# Image-only pipeline: tensor conversion followed by per-channel
# normalization with the commonly used ImageNet mean/std values.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Targets are only converted to tensors (no normalization).
target_transform = transforms.ToTensor()

# Training data; incomplete final batches are dropped.
train_set = ImageFolder(duts_train_path, joint_transform, img_transform,
                        target_transform)
train_loader = DataLoader(train_set,
                          batch_size=args['train_batch_size'],
                          num_workers=12,
                          shuffle=True,
                          drop_last=True)

# Binary cross-entropy loss; .cuda() is applied at module import time.
criterionBCE = nn.BCELoss().cuda()


def main():
    """Start training under the experiment name 'dpnet'."""
    train('dpnet')


def train(exp_name):
Пример #8
0
from datasets import ImageFolder
import torchvision.transforms as transforms
from PIL import Image
from itertools import izip

# Location of the raw input images, of the experiment results, and an output
# name (how outname is used is not visible in this part of the script).
rawroot = '/mnt/Data1/Water_Real'
outroot = './results'
outname = 'concat'

# Datasets that are iterated in lockstep further below; index 0 is the input.
datasets = []

# input images, resized and centre-cropped to 256; return_path=True is passed
# so each item presumably also carries its file path -- TODO confirm
datasets.append(
    ImageFolder(rawroot,
                transform=transforms.Compose(
                    [transforms.Resize(256),
                     transforms.CenterCrop(256)]),
                return_path=True))

# result images of each experiment, read from <outroot>/<exp_name>_test
for exp_name in [
        'warp_L1', 'warp_L1VGG', 'color_L1VGG', 'color_L1VGGAdv',
        'both_L1VGGAdv'
]:
    datasets.append(
        ImageFolder(os.path.join(outroot, '%s_test' % exp_name),
                    return_path=True))

# concat and save each image
for i, imgs in enumerate(izip(*datasets)):
    name = imgs[0][-1]
Пример #9
0
    'momentum': 0.9,
    'snapshot': ''
}

# Random 300 crop, horizontal flip and rotation of up to 10; presumably
# applied to image and target together -- TODO confirm in joint_transforms.
joint_transform = joint_transforms.Compose([
    joint_transforms.RandomCrop(300),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10)
])
# Image-only pipeline: tensor conversion followed by per-channel
# normalization with the commonly used ImageNet mean/std values.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Targets are only converted to tensors (no normalization).
target_transform = transforms.ToTensor()

# Training data, reshuffled by the loader; the last batch may be smaller
# because drop_last is not set.
train_set = ImageFolder(msra10k_path, joint_transform, img_transform,
                        target_transform)
train_loader = DataLoader(train_set,
                          batch_size=args['train_batch_size'],
                          num_workers=12,
                          shuffle=True)

# Loss that takes raw logits; .cuda() is applied at module import time.
criterion = nn.BCEWithLogitsLoss().cuda()
# Log file named after the import-time timestamp, inside the experiment's
# checkpoint folder.
log_path = os.path.join(ckpt_path, exp_name,
                        str(datetime.datetime.now()) + '.txt')


def main():
    net = R3Net().cuda().train()

    optimizer = optim.SGD([{
        'params': [
Пример #10
0
    joint_transforms.RandomCrop(448),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10)
])
# Image-only pipeline: tensor conversion followed by per-channel
# normalization with the commonly used ImageNet mean/std values.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# NOTE(review): the depth pipeline reuses the same three-channel statistics as
# the RGB images -- confirm the depth maps are loaded with three channels.
depth_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Targets are only converted to tensors (no normalization).
target_transform = transforms.ToTensor()
# Converts tensors back to PIL images.
to_pil = transforms.ToPILImage()

# Training data; args['status'] is forwarded to the dataset as its second
# argument.
train_set = ImageFolder(dutsk_path, args['status'],joint_transform, img_transform, target_transform,depth_transform)
train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=12, shuffle=True)

# Loss that takes raw logits; .cuda() is applied at module import time.
criterion = nn.BCEWithLogitsLoss().cuda()
# Log file named after the import-time timestamp, inside the experiment's
# checkpoint folder.
log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')


def main():
    net = R3Net(num_class=1).cuda().train()
    # net = nn.DataParallel(net,device_ids=[0,1])

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
Пример #11
0
def select_worst_images(args, model, full_train_loader, device):
    """Build a training set from the highest-loss samples of the full loader.

    The model is switched to eval mode and every batch of
    ``full_train_loader`` (yielding ``(path, x, y)``) is scored with an
    unreduced cross-entropy loss. The ``args.noise`` fraction of samples with
    the largest loss has its loss set to zero, so those samples rank lowest
    in the selection. Afterwards the ``args.data_fraction`` fraction of
    samples with the largest remaining loss is selected.

    Args:
        args: namespace providing ``batch_size``, ``batch_split``,
            ``dataset``, ``input_channels`` (1, 2 or 3), ``noise``,
            ``data_fraction`` and ``workers``.
        model: network producing class logits for a batch ``x``.
        full_train_loader: loader over the complete training set.
        device: device the batches are moved to.

    Returns:
        ``(smart_train_set, smart_train_loader)`` for the selected samples.

    Raises:
        ValueError: if ``args.input_channels`` is not 1, 2 or 3. The check
            runs before the scoring pass; previously an unsupported value
            only surfaced as a NameError after the whole pass.

    Note:
        A ``data_fraction`` that rounds down to zero samples yields an empty
        selection. The previous ``[-0:]`` slice selected *every* sample in
        that case.
    """
    print("Selecting images for next epoch training...")

    channels = args.input_channels
    if channels not in (1, 2, 3):
        raise ValueError(
            'unsupported number of input channels: {}'.format(channels))

    model.eval()

    gts = []
    paths = []
    losses = []

    micro_batch_size = args.batch_size // args.batch_split
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    # One mean/std entry of 0.5 per input channel.
    norm_stats = (0.5,) * int(channels)
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(norm_stats, norm_stats),
    ])

    # reduction='none' keeps one loss value per sample.
    per_sample_loss = torch.nn.CrossEntropyLoss(reduction='none')

    pbar = tqdm.tqdm(enumerate(full_train_loader),
                     total=len(full_train_loader))

    with torch.no_grad():
        for _, (path, x, y) in pbar:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # compute output and record the per-sample loss.
            logits = model(x)

            paths.extend(path)
            gts.extend(y.cpu().numpy())

            c = per_sample_loss(logits, y)

            losses.extend(
                c.cpu().numpy().tolist())  # Also ensures a sync point.

    gts = np.array(gts)
    losses = np.array(losses)
    total = losses.shape[0]

    # Zero the loss of the `noise` fraction with the largest loss so that
    # these samples sort to the bottom of the ranking below.
    noisy_start = int(total * (1.0 - args.noise))
    losses[np.argsort(losses)[noisy_start:]] = 0.0

    # Take the `selection_count` samples with the largest remaining loss.
    # Slicing from an explicit start index avoids the `[-0:]` pitfall, which
    # returns the whole array instead of nothing.
    selection_count = int(args.data_fraction * total)
    order = np.argsort(losses)
    keep = order[max(total - selection_count, 0):]
    paths_ = np.array(paths)[keep]
    gts_ = gts[keep]

    smart_train_set = ImageFolder(paths_, gts_, train_tx, crop)

    smart_train_loader = torch.utils.data.DataLoader(
        smart_train_set,
        batch_size=micro_batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False)

    return smart_train_set, smart_train_loader
Пример #12
0
    plt.show()

    x = SynthData(256, n=1)
    """
    for i in range(100):
        x.step(steps=3)
        plt.imshow(x.u[0].numpy())
        plt.title('%d %f %f %f'%(i,x.u[0].min(), x.u[0].mean(), x.u[0].max()))
        plt.draw()
        plt.pause(.1)


    """
    from datasets import ImageFolder
    import torchvision.transforms as transforms

    img_dir = '/mnt/Data1/ImageNet/val'
    data = ImageFolder(img_dir,
                       transform=transforms.Compose([
                           transforms.Resize(256),
                           transforms.CenterCrop(256),
                           transforms.ToTensor(),
                       ]))
    img = data[10][0]

    for i in range(100):
        img1 = x(img)[0]
        plt.imshow(img1.permute(1, 2, 0))
        plt.draw()
        plt.pause(.1)
Пример #13
0
 image_transform = transforms.Compose([
     transforms.Resize(int(imsize * 76 / 64)),
     transforms.RandomCrop(imsize),
     transforms.RandomHorizontalFlip()
 ])
 if cfg.DATA_DIR.find('lsun') != -1:
     from datasets import LSUNClass
     dataset = LSUNClass('%s/%s_%s_lmdb' %
                         (cfg.DATA_DIR, cfg.DATASET_NAME, split_dir),
                         base_size=cfg.TREE.BASE_SIZE,
                         transform=image_transform)
 elif cfg.DATA_DIR.find('imagenet') != -1:
     from datasets import ImageFolder
     dataset = ImageFolder(cfg.DATA_DIR,
                           split_dir='train',
                           custom_classes=CLASS_DIC[cfg.DATASET_NAME],
                           base_size=cfg.TREE.BASE_SIZE,
                           transform=image_transform)
 elif cfg.DATA_DIR.find('xray_ct_scan') != -1:
     from datasets import ImageFolder
     dataset = ImageFolder(
         cfg.DATA_DIR,
         split_dir='train',
         custom_classes=["Covid-19", "No_findings", "Pneumonia"],
         base_size=cfg.TREE.BASE_SIZE,
         transform=image_transform)
 elif cfg.GAN.B_CONDITION:  # text to image task
     from datasets import TextDataset
     dataset = TextDataset(cfg.DATA_DIR,
                           split_dir,
                           base_size=cfg.TREE.BASE_SIZE,
Пример #14
0
def main(args):
    """Train a ConvNet with SGD and validate after every epoch.

    Optionally resumes model, optimizer, epoch and best accuracy from
    ``args.resume``. Each epoch logs train/validation loss and accuracy and
    saves a checkpoint, flagged as best when validation accuracy improved.
    """
    best_acc1 = 0
    os.makedirs('checkpoints', exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args.device = device
    print('device: {}'.format(args.device))

    # Model, loss and optimizer.
    model = ConvNet(cfg.NUM_CLASSES).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Resume when a checkpoint path was given; a missing file is only
    # reported and training proceeds without restoring anything.
    resume_path = args.resume
    if resume_path and os.path.isfile(resume_path):
        print("=> loading checkpoint '{}'".format(resume_path))
        saved = torch.load(resume_path, map_location=device)

        args.start_epoch = saved['epoch']
        best_acc1 = saved['best_acc1']
        model.load_state_dict(saved['state_dict'])
        optimizer.load_state_dict(saved['optimizer'])

        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume_path, saved['epoch']))
    elif resume_path:
        print("=> no checkpoint found at '{}'".format(resume_path))

    cudnn.benchmark = True

    # Datasets and loaders; only the training loader shuffles.
    train_dataset = ImageFolder(cfg.TRAIN_PATH)
    val_dataset = ImageFolder(cfg.VAL_PATH)

    shared_loader_args = dict(batch_size=args.batch_size,
                              num_workers=args.workers,
                              pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=True,
                                               **shared_loader_args)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             shuffle=False,
                                             **shared_loader_args)

    logger = Logger('./logs')
    tags = ('train_loss', 'train_acc', 'val_loss', 'val_acc')
    for epoch in range(args.start_epoch, args.epochs):
        # one epoch of training, then validation
        adjust_learning_rate(optimizer, epoch, args)
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args)
        val_loss, val_acc = validate(val_loader, model, criterion, args)

        # track the best validation accuracy seen so far
        is_best = val_acc > best_acc1
        best_acc1 = max(val_acc, best_acc1)

        # convert all four metrics first, then log them in a fixed order
        metrics = [float(m)
                   for m in (train_loss, train_acc, val_loss, val_acc)]
        for tag, value in zip(tags, metrics):
            logger.scalar_summary(tag, value, epoch)

        snapshot = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(snapshot, is_best)
Пример #15
0
    def __init__(self, options):
        """Set up data, networks, optimizers and fixed visualization samples.

        Builds the image dataset (optionally paired with a weight dataset
        when ``options.weight_root`` is given), then creates and registers
        four models with Adam optimizers: encoder ('enc'), generator ('gen'),
        classifier ('cla') and style bank ('sty'). The fixed images, labels
        and optional weights used for visualization are loaded from
        ``<load_path>/reconstructions`` when resuming, otherwise sampled at
        random from the dataset; they are written to the save path when it
        differs from the load path. Finally the per-dimension dropout
        probabilities for content and style codes are precomputed.

        Attributes such as ``self.device``, ``self.lr``, ``self.sty_lr``,
        ``self.mlp``, ``self.nclass``, ``self.vis_col``, ``self.load_path``
        and ``self.save_path`` are not assigned here; presumably the
        base-class constructor sets them from ``copy_keys`` -- TODO confirm.
        """
        super(Stage1Trainer, self).__init__(options,
                                            subfolders=['reconstructions'],
                                            copy_keys=copy_keys)

        # Shared geometric part of the pipeline (used for images and weights).
        transforms = []
        if options.crop_size is not None:
            transforms.append(T.CenterCrop(options.crop_size))
        transforms.append(T.Resize(options.image_size))
        transforms.append(T.ToTensor())

        # Images additionally get normalized: first three channels with
        # (x - 0.5) / 0.5, a fourth channel (mean 0, std 1) stays unchanged.
        image_transforms = transforms + [
            T.Normalize((0.5, 0.5, 0.5, 0), (0.5, 0.5, 0.5, 1))
        ]
        image_set = ImageFolder(options.data_root,
                                transform=T.Compose(image_transforms))

        if options.weight_root is not None:
            self.has_weight = True
            # Weight maps keep only the first channel of the loaded tensor.
            weight_transforms = transforms + [lambda x: x[0]]
            weight_set = ImageFolder(options.weight_root,
                                     transform=T.Compose(weight_transforms))
            self.dataset = ParallelDataset(image_set, weight_set)
        else:
            self.has_weight = False
            self.dataset = image_set

        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=options.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=options.nloader)
        self.data_iter = iter(self.dataloader)

        # Encoder.
        self.enc = models.Encoder(options.image_size, options.image_size,
                                  options.enc_features, options.enc_blocks,
                                  options.enc_adain_features,
                                  options.enc_adain_blocks,
                                  options.content_size)
        self.enc.to(self.device)
        self.enc_optim = optim.Adam(self.enc.parameters(),
                                    lr=self.lr,
                                    eps=1e-4)
        self.add_model('enc', self.enc, self.enc_optim)

        # Generator taking both the content size and the style size.
        self.gen = models.TwoPartNestedDropoutGenerator(
            options.image_size, options.image_size, options.gen_features,
            options.gen_blocks, options.gen_adain_features,
            options.gen_adain_blocks, options.content_size, options.style_size)
        self.gen.to(self.device)
        self.gen_optim = optim.Adam(self.gen.parameters(),
                                    lr=self.lr,
                                    eps=1e-4)
        self.add_model('gen', self.gen, self.gen_optim)

        # Classifier: an MLP sized by the content size when self.mlp is set,
        # otherwise a network sized by the image size.
        if self.mlp:
            self.cla = models.MLPClassifier(options.content_size,
                                            options.mlp_features,
                                            options.mlp_layers, self.nclass)
        else:
            self.cla = models.ClassifierOrDiscriminator(
                options.image_size, options.image_size, options.cla_features,
                options.cla_blocks, options.cla_adain_features,
                options.cla_adain_blocks, self.nclass)
        self.cla.to(self.device)
        self.cla_optim = optim.Adam(self.cla.parameters(),
                                    lr=self.lr,
                                    eps=1e-4)
        self.add_model('cla', self.cla, self.cla_optim)

        # Style bank; uses its own learning rate and a smaller Adam eps than
        # the other models.
        self.sty = models.NormalizedStyleBank(self.nclass, options.style_size,
                                              image_set.get_class_freq())
        self.sty.to(self.device)
        self.sty_optim = optim.Adam(self.sty.parameters(),
                                    lr=self.sty_lr,
                                    eps=1e-8)
        self.add_model('sty', self.sty, self.sty_optim)

        if self.load_path is not None:
            # Resuming: reuse the saved visualization samples, then restore
            # the training state.
            self.vis_images = torch.load(os.path.join(self.load_path,
                                                      'reconstructions',
                                                      'images.pt'),
                                         map_location=self.device)
            self.vis_labels = torch.load(os.path.join(self.load_path,
                                                      'reconstructions',
                                                      'labels.pt'),
                                         map_location=self.device)
            if self.has_weight:
                self.vis_weights = torch.load(os.path.join(
                    self.load_path, 'reconstructions', 'weights.pt'),
                                              map_location=self.device)
            self.load(options.load_iter)
        else:
            # Fresh run: draw vis_row * vis_col distinct dataset indices.
            vis_images = []
            vis_labels = []
            if self.has_weight:
                vis_weights = []
            vis_index = random.sample(range(len(image_set)),
                                      options.vis_row * options.vis_col)
            for k in vis_index:
                image, label = image_set[k]
                vis_images.append(image)
                vis_labels.append(label)
                if self.has_weight:
                    # The same index is used to fetch the matching weight map.
                    weight, _ = weight_set[k]
                    vis_weights.append(weight)
            self.vis_images = torch.stack(vis_images, dim=0).to(self.device)
            self.vis_labels = one_hot(
                torch.tensor(vis_labels, dtype=torch.int32),
                self.nclass).to(self.device)
            if self.has_weight:
                self.vis_weights = torch.stack(vis_weights,
                                               dim=0).to(self.device)

        # Store the visualization samples next to the new run's outputs.
        if self.save_path != self.load_path:
            torch.save(
                self.vis_images,
                os.path.join(self.save_path, 'reconstructions', 'images.pt'))
            torch.save(
                self.vis_labels,
                os.path.join(self.save_path, 'reconstructions', 'labels.pt'))
            # add(1).div(2) undoes the (x - 0.5) / 0.5 normalization above.
            save_image(
                self.vis_images.add(1).div(2),
                os.path.join(self.save_path, 'reconstructions', 'target.png'),
                self.vis_col)
            if self.has_weight:
                torch.save(
                    self.vis_weights,
                    os.path.join(self.save_path, 'reconstructions',
                                 'weights.pt'))
                # unsqueeze(1) inserts a dimension at index 1 before saving
                # the weight maps as an image grid.
                save_image(
                    self.vis_weights.unsqueeze(1),
                    os.path.join(self.save_path, 'reconstructions',
                                 'weight.png'), self.vis_col)

        self.add_periodic_func(self.visualize_fixed, options.visualize_iter)
        self.visualize_fixed()

        # Entry i holds dropout ** i: 1 for i = 0, then a geometric sequence.
        self.con_drop_prob = torch.Tensor(options.content_size)
        for i in range(options.content_size):
            self.con_drop_prob[i] = options.content_dropout**i
        self.sty_drop_prob = torch.Tensor(options.style_size)
        for i in range(options.style_size):
            self.sty_drop_prob[i] = options.style_dropout**i
Пример #16
0
    'snapshot': '',
    'epoch': 60
}

# Random 300 crop, horizontal flip and rotation of up to 10; presumably
# applied to image and target together -- TODO confirm in joint_transforms.
joint_transform = joint_transforms.Compose([
    joint_transforms.RandomCrop(300),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10)
])
# Image-only pipeline: tensor conversion followed by per-channel
# normalization with the commonly used ImageNet mean/std values.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Targets are only converted to tensors (no normalization).
target_transform = transforms.ToTensor()

# Training data, reshuffled by the loader; the last batch may be smaller
# because drop_last is not set.
train_set = ImageFolder(kaist_path, joint_transform, img_transform,
                        target_transform)
train_loader = DataLoader(train_set,
                          batch_size=args['train_batch_size'],
                          num_workers=6,
                          shuffle=True)

# Loss that takes raw logits; .cuda() is applied at module import time.
criterion = nn.BCEWithLogitsLoss().cuda()
# Log file named after the import-time timestamp, inside the experiment's
# checkpoint folder.
log_path = os.path.join(ckpt_path, exp_name,
                        str(datetime.datetime.now()) + '.txt')


def main():
    net = R3Net().cuda().train()

    optimizer = optim.SGD([{
        'params': [
Пример #17
0
def mktrainval(args, logger):
    """Build the training/validation datasets and their data loaders.

    Args:
        args: Parsed options. This function reads ``dataset``,
            ``input_channels``, ``datadir``, ``examples_per_class``,
            ``batch_size``, ``batch_split`` and ``workers``.
        logger: Logger used for progress messages (only ``info`` is called).

    Returns:
        The tuple ``(train_set, valid_set, train_loader, valid_loader,
        train_loader_val)``.

    Raises:
        ValueError: If ``args.input_channels`` is not 1, 2 or 3, or if
            ``args.dataset`` is not one of the supported dataset names.
    """
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    # Fail early with a clear message instead of hitting an unbound
    # ``train_tx`` further down.
    if args.input_channels not in (1, 2, 3):
        raise ValueError(
            f"Unsupported number of input channels: {args.input_channels} "
            f"(expected 1, 2 or 3).")

    # One 0.5 entry per channel for both mean and std.
    channel_stats = (0.5,) * args.input_channels

    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(channel_stats, channel_stats),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(channel_stats, channel_stats),
    ])

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=train_tx,
                                        train=True,
                                        download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=val_tx,
                                        train=False,
                                        download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=train_tx,
                                         train=True,
                                         download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=val_tx,
                                         train=False,
                                         download=True)
    elif args.dataset == "imagenet2012":

        def load_split(split, transform):
            # Files live in <datadir>/<split>/<class dir>/<file>; the class
            # directory name is mapped to a label through ``class_dict``.
            folder_path = pjoin(args.datadir, split)
            files = sorted(glob.glob("%s/*/*.*" % folder_path))
            labels = [class_dict[path.split("/")[-2]] for path in files]
            return ImageFolder(files, labels, transform, crop)

        train_set = load_split("train", train_tx)
        valid_set = load_split("val", val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    if args.examples_per_class is not None:
        logger.info(
            f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch_size // args.batch_split
    micro_batch_size_val = 4 * micro_batch_size

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=micro_batch_size_val,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    # Built unconditionally: previously this loader was only created in the
    # "large dataset" branch, so the few-shot branch crashed at ``return``.
    # With drop_last=False a dataset smaller than the batch size simply
    # yields one short batch.
    train_loader_val = torch.utils.data.DataLoader(
        train_set,
        batch_size=micro_batch_size_val,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=micro_batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than
        # the batch size. The default sampler doesn't repeat, so sample with
        # replacement to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(
                train_set, replacement=True, num_samples=micro_batch_size))

    return train_set, valid_set, train_loader, valid_loader, train_loader_val
Пример #18
0
def main():
    """Parse command-line options, build the data loaders and model, then
    either test only (``--test``) or train for ``--epochs`` epochs.

    After every training epoch the model weights are written to
    ``<outroot>/<exp-name>_net.pth`` and the test routine is run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot',
                        default='/mnt/Data1',
                        help='path to images')
    parser.add_argument('--workers',
                        default=4,
                        type=int,
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--batch-size',
                        default=16,
                        type=int,
                        help='mini-batch size')
    parser.add_argument('--outroot',
                        default='./results',
                        help='path to save the results')
    parser.add_argument('--exp-name',
                        default='test',
                        help='name of expirement')
    parser.add_argument('--load',
                        default='',
                        help='name of pth to load weights from')
    parser.add_argument('--freeze-cc-net',
                        dest='freeze_cc_net',
                        action='store_true',
                        help='dont train the color corrector net')
    parser.add_argument('--freeze-warp-net',
                        dest='freeze_warp_net',
                        action='store_true',
                        help='dont train the warp net')
    parser.add_argument('--test',
                        dest='test',
                        action='store_true',
                        help='only test the network')
    parser.add_argument(
        '--synth-data',
        dest='synth_data',
        action='store_true',
        help='use synthetic data instead of tank data for training')
    parser.add_argument('--epochs',
                        default=3,
                        type=int,
                        help='number of epochs to train for')
    parser.add_argument('--no-warp-net',
                        dest='warp_net',
                        action='store_false',
                        help='do not include warp net in the model')
    parser.add_argument('--warp-net-downsample',
                        default=3,
                        type=int,
                        help='number of downsampling layers in warp net')
    parser.add_argument('--no-color-net',
                        dest='color_net',
                        action='store_false',
                        help='do not include color net in the model')
    parser.add_argument('--color-net-downsample',
                        default=3,
                        type=int,
                        help='number of downsampling layers in color net')
    parser.add_argument(
        '--no-color-net-skip',
        dest='color_net_skip',
        action='store_false',
        help='dont use u-net skip connections in the color net')
    parser.add_argument(
        '--dim',
        default=32,
        type=int,
        help='initial feature dimension (doubled at each downsampling layer)')
    parser.add_argument('--n-res',
                        default=8,
                        type=int,
                        help='number of residual blocks')
    parser.add_argument('--norm',
                        default='gn',
                        type=str,
                        help='type of normalization layer')
    parser.add_argument(
        '--denormalize',
        dest='denormalize',
        action='store_true',
        help='denormalize output image by input image mean/var')
    parser.add_argument(
        '--weight-X-L1',
        default=1.,
        type=float,
        help='weight of L1 reconstruction loss after color corrector net')
    parser.add_argument('--weight-Y-L1',
                        default=1.,
                        type=float,
                        help='weight of L1 reconstruction loss after warp net')
    parser.add_argument('--weight-Y-VGG',
                        default=1.,
                        type=float,
                        help='weight of perceptual loss after warp net')
    parser.add_argument(
        '--weight-Z-L1',
        default=1.,
        type=float,
        help='weight of L1 reconstruction loss after color net')
    parser.add_argument('--weight-Z-VGG',
                        default=.5,
                        type=float,
                        help='weight of perceptual loss after color net')
    parser.add_argument('--weight-Z-Adv',
                        default=0.2,
                        type=float,
                        help='weight of adversarial loss after color net')
    args = parser.parse_args()

    # set random seed for consistent fixed batch
    torch.manual_seed(8)

    # set weights of losses of intermediate outputs to zero if not necessary
    if not args.warp_net:
        args.weight_Y_L1 = 0
        args.weight_Y_VGG = 0
    if not args.color_net:
        args.weight_Z_L1 = 0
        args.weight_Z_VGG = 0
        args.weight_Z_Adv = 0

    # datasets
    train_dir_1 = os.path.join(args.dataroot, 'Water', 'train')
    train_dir_2 = os.path.join(args.dataroot, 'ImageNet', 'train')
    val_dir_1 = os.path.join(args.dataroot, 'Water', 'test')
    val_dir_2 = os.path.join(args.dataroot, 'ImageNet', 'test')
    test_dir = os.path.join(args.dataroot, 'Water_Real')

    if args.synth_data:
        train_data = ImageFolder(train_dir_2,
                                 transform=transforms.Compose([
                                     transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                          std=[0.5, 0.5, 0.5]),
                                     synthdata.SynthData(224,
                                                         n=args.batch_size),
                                 ]))
    else:
        train_data = PairedImageFolder(
            train_dir_1,
            train_dir_2,
            transform=transforms.Compose([
                pairedtransforms.RandomResizedCrop(224),
                pairedtransforms.RandomHorizontalFlip(),
                pairedtransforms.ToTensor(),
                pairedtransforms.Normalize(mean=[0.5, 0.5, 0.5],
                                           std=[0.5, 0.5, 0.5]),
            ]))
    val_data = PairedImageFolder(val_dir_1,
                                 val_dir_2,
                                 transform=transforms.Compose([
                                     pairedtransforms.Resize(256),
                                     pairedtransforms.CenterCrop(256),
                                     pairedtransforms.ToTensor(),
                                     pairedtransforms.Normalize(
                                         mean=[0.5, 0.5, 0.5],
                                         std=[0.5, 0.5, 0.5]),
                                 ]))
    test_data = ImageFolder(test_dir,
                            transform=transforms.Compose([
                                transforms.Resize(256),
                                transforms.CenterCrop(256),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                     std=[0.5, 0.5, 0.5]),
                            ]),
                            return_path=True)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              shuffle=False)

    # fixed test batch for visualization during training
    # Use the built-in next(): loader iterators only guarantee __next__,
    # and the Python 2 style ``.next()`` alias is gone in current PyTorch.
    fixed_batch = next(iter(val_loader))[0]

    # model
    model = networks.Model(args)
    model.cuda()

    # load weights from checkpoint
    # In test mode without an explicit --load, fall back to this
    # experiment's own checkpoint.
    if args.test and not args.load:
        args.load = args.exp_name
    if args.load:
        # Key matching is strict only when testing.
        model.load_state_dict(torch.load(
            os.path.join(args.outroot, '%s_net.pth' % args.load)),
                              strict=args.test)

    # create outroot if necessary
    if not os.path.exists(args.outroot):
        os.makedirs(args.outroot)

    # if args.test only run test script
    if args.test:
        test(test_loader, model, args)
        return

    # main training loop
    for epoch in range(args.epochs):
        train(train_loader, model, fixed_batch, epoch, args)
        torch.save(model.state_dict(),
                   os.path.join(args.outroot, '%s_net.pth' % args.exp_name))
        test(test_loader, model, args)
Пример #19
0
def main(args):
    """Train a VectorQuantizedVAE and log reconstructions to TensorBoard.

    Args:
        args: Parsed options. This function reads ``output_folder``,
            ``dataset``, ``data_folder``, ``image_size``, ``batch_size``,
            ``num_workers``, ``hidden_size``, ``k``, ``device``, ``lr`` and
            ``num_epochs``.

    Side effects:
        Writes ``true.png`` and ``rec.png`` to the working directory, logs
        under ``./logs/<output_folder>`` and saves ``best.pt`` plus one
        ``model_<epoch>.pt`` per epoch under ``./models/<output_folder>``.
    """
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)
    # Checkpoints are opened with "wb" below, which fails if the directory
    # is missing; creating it here is a no-op when the caller already did.
    if not os.path.exists(save_filename):
        os.makedirs(save_filename)

    if args.dataset in ["mnist", "fashion-mnist", "cifar10"]:
        builtin_datasets = {
            "mnist": (datasets.MNIST, 1),
            "fashion-mnist": (datasets.FashionMNIST, 1),
            "cifar10": (datasets.CIFAR10, 3),
        }
        dataset_cls, num_channels = builtin_datasets[args.dataset]
        # Normalize needs one mean/std entry per channel: a 3-entry tuple on
        # 1-channel MNIST tensors fails on torchvision versions that
        # broadcast the statistics in place.
        channel_stats = (0.5,) * num_channels
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(channel_stats, channel_stats),
            ]
        )
        # Define the train & test datasets
        train_dataset = dataset_cls(
            args.data_folder, train=True, download=True, transform=transform
        )
        test_dataset = dataset_cls(
            args.data_folder, train=False, transform=transform
        )
        valid_dataset = test_dataset
    elif args.dataset == "miniimagenet":
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(
            args.data_folder, train=True, download=True, transform=transform
        )
        valid_dataset = MiniImagenet(
            args.data_folder, valid=True, download=True, transform=transform
        )
        test_dataset = MiniImagenet(
            args.data_folder, test=True, download=True, transform=transform
        )
        num_channels = 3
    else:
        # Any other dataset name is treated as an image folder with
        # "train" and "val" sub-directories.
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = ImageFolder(
            os.path.join(args.data_folder, "train"), transform=transform
        )
        valid_dataset = ImageFolder(
            os.path.join(args.data_folder, "val"), transform=transform
        )
        test_dataset = valid_dataset
        num_channels = 3

    # NOTE(review): the training loader does not shuffle; kept as in the
    # original -- confirm this is intended.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    # NOTE(review): newer torchvision releases name make_grid's ``range``
    # argument ``value_range`` -- confirm against the installed version.
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    save_image(fixed_grid, "true.png")
    writer.add_image("original", fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    save_image(grid, "rec.png")
    writer.add_image("reconstruction", grid, 0)

    best_loss = -1
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print(epoch, "test loss: ", loss)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        save_image(grid, "rec.png")

        writer.add_image("reconstruction", grid, epoch + 1)

        # The first epoch always counts as the best so far; afterwards only
        # a strictly lower validation loss replaces best.pt.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
Пример #20
0
def train():
    """Train the refinement generator and discriminator on top of a
    pre-trained auto-encoder.

    All settings come from the ``config`` module. The auto-encoder is loaded
    from a checkpoint, put in eval mode and is not handed to any optimizer;
    only the refinement generator (``net_ig``) and the discriminator
    (``net_id``) are optimized. An exponential moving average of the
    generator parameters is used for the preview images.

    Side effects:
        Writes preview images, checkpoints and a text log below the folders
        returned by ``make_folders``, and caches the real-image FID
        statistics in ``<DATA_NAME>_fid_feats.npy`` in the working directory.

    Raises:
        ValueError: If ``DATA_NAME`` has no matching refinement generator.
    """
    from benchmark import (calc_fid, extract_feature_from_generator_fn,
                           load_patched_inception_v3, real_image_loader,
                           image_generator, image_generator_perm)
    import lpips

    from config import (IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS,
                        DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE,
                        GAN_CKECKPOINT)
    from config import (SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL,
                        LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME,
                        MULTI_GPU)
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import (data_root_colorful, data_root_sketch_1,
                        data_root_sketch_2, data_root_sketch_3)

    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    saved_image_folder, saved_model_folder = make_folders(
        SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
    log_file_path = saved_image_folder + '/../gan_log.txt'
    # Create (or truncate) the log file for this run.
    with open(log_file_path, 'w'):
        pass

    def append_log(message):
        # Append one line to the run log, closing the file every time.
        with open(log_file_path, 'a') as log_file:
            log_file.write(message + '\n')

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_GAN,
                                 rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_GAN,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker

    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)

    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    net_ae.eval()

    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME == 'art' or DATA_NAME == 'shoe':
        from models import RefineGenerator_art as RefineGenerator
    else:
        # Previously this fell through and crashed with
        # "'NoneType' object is not callable".
        raise ValueError('unsupported DATA_NAME: %r' % (DATA_NAME,))
    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    # we use the patch_gan, so the im_size for D should be 512 even if
    # training image size is 1024
    net_id = Discriminator(nc=3).cuda()

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    # running averages reported in the log
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    # NOTE(review): losses_rec_s is logged and reset but never updated in
    # this function.
    losses_rec_s = AverageMeter()

    losses_rec_ae = AverageMeter()

    fixed_skt = fixed_rgb = fixed_perm = None

    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(10000)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()

            # NOTE(review): randint includes both end points, so sketch 3 is
            # chosen for rd in {2, 3}, i.e. half of the time -- confirm this
            # weighting is intended.
            rd = random.randint(0, 3)
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            if iteration == 0:
                # Preview batch; refreshed at the start of every epoch.
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D
            gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            # Tracked for logging only; it is not part of any backward pass.
            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(
                gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            if DATA_NAME == 'shoe':
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) +
                                 F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction
                perm = true_randperm(BATCH_SIZE_GAN)
                img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])

                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            # Exponential moving average of the generator parameters.
            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:  # show the current images
                with torch.no_grad():

                    # Temporarily swap in the EMA weights for the previews.
                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)

                    gimg_ae_perm, style_feats = net_ae(fixed_skt,
                                                       fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1),
                                      IM_SIZE_GAN), gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN), gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN)
                    ])

                    vutils.save_image(
                        gimg,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True,
                        range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg'
                    )

                    # Restore the live training weights.
                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Iter: [{0}/{1}] G: {losses_g_img.avg:.4f}  D: {losses_d_img.avg:.4f}  MSE: {losses_mse.avg:.4f}  Rec: {losses_rec_s.avg:.5f}  FID: {fid:.4f}'.format(
                    epoch,
                    iteration,
                    losses_g_img=losses_g_img,
                    losses_d_img=losses_d_img,
                    losses_mse=losses_mse,
                    losses_rec_s=losses_rec_s,
                    fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))

                append_log(log_msg)

                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
                print('Saving history model')
                # One checkpoint file per epoch, overwritten within the epoch.
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                if real_features is None:
                    feats_cache_path = '%s_fid_feats.npy' % (DATA_NAME)
                    if os.path.exists(feats_cache_path):
                        # The cache is a pickle written by this function
                        # (despite the .npy suffix); only load trusted files.
                        with open(feats_cache_path, 'rb') as cache_file:
                            real_features = pickle.load(cache_file)
                    else:
                        feats = extract_feature_from_generator_fn(
                            real_image_loader(dataloader,
                                              n_batches=fid_batch_images),
                            inception)
                        real_features = {
                            'feats': feats,
                            'mean': np.mean(feats, 0),
                            'cov': np.cov(feats, rowvar=False)
                        }
                        with open(feats_cache_path, 'wb') as cache_file:
                            pickle.dump(real_features, cache_file)

                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset,
                                    net_ae,
                                    net_ig,
                                    n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid = calc_fid(sample_features,
                                   real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset,
                                         net_ae,
                                         net_ig,
                                         n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm,
                                        real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                append_log('fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1]))
Пример #21
0
from datasets import ImageFolder
import cfg


# Command-line interface: the image to classify and the checkpoint to load.
parser = argparse.ArgumentParser(description='PyTorch RPC Predicting')
parser.add_argument(
    '-i', '--image',
    type=str,
    metavar='PATH',
    help='path to image',
)
parser.add_argument(
    '-w', '--model_weight',
    type=str,
    metavar='PATH',
    default=cfg.WEIGHTS_PATH,
    help='path to latest checkpoint',
)
args = parser.parse_args()
print(args)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device:' + str(device))

# Class names: invert the dataset's name -> index mapping so that a
# predicted index can be turned back into its class name.
dataset = ImageFolder(cfg.VAL_PATH)
class_to_idx = dataset.class_to_idx
idx_to_class = {index: name for name, index in class_to_idx.items()}
print(idx_to_class)

# load data
def load_image(imgPath):
    """Read the image at ``imgPath`` and return it as a normalised tensor.

    The image is converted to RGB, passed through ``Resize(256)`` and
    ``CenterCrop(224)``, turned into a tensor and normalised per channel
    with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    rgb_image = Image.open(imgPath).convert('RGB')
    return pipeline(rgb_image)
Пример #22
0
    'momentum': 0.9,
    'snapshot': ''
}

# Augmentation pipeline built from the project's joint_transforms module
# (presumably applied to an image and its target together -- verify against
# joint_transforms): random 300 crop, random horizontal flip, rotation by 10.
joint_transform = joint_transforms.Compose([
    joint_transforms.RandomCrop(300),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10),
])

# Image-only preprocessing: tensor conversion followed by per-channel
# normalisation with the mean/std listed here.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Targets are only converted to tensors.
target_transform = transforms.ToTensor()

train_set = ImageFolder(cvpr2014_trainning_path, joint_transform,
                        img_transform, target_transform)
train_loader = DataLoader(train_set,
                          shuffle=True,
                          num_workers=12,
                          batch_size=args['train_batch_size'])

criterion = nn.BCEWithLogitsLoss().cuda()

# The log file is named after the timestamp taken when the module is loaded.
log_path = os.path.join(ckpt_path, exp_name,
                        '{}.txt'.format(datetime.datetime.now()))


def main():
    net = BR2Net().cuda().train()

    optimizer = optim.SGD([{
        'params': [
Пример #23
0
    'snapshot': ''
}
# ------------------------- data augmentation -------------------------
# Pipeline built from the project's joint_transforms module (presumably
# applied to an image and its target together -- verify against
# joint_transforms).
joint_transform = joint_transforms.Compose([
    joint_transforms.RandomCrop(384, 384),
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomRotate(10),
])

# Image-only preprocessing: light colour jitter, tensor conversion and
# per-channel normalisation with the mean/std listed here.
img_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Targets are only converted to tensors.
target_transform = transforms.ToTensor()
# ----------------------------------------------------------------------

train_set = ImageFolder(train_data, joint_transform, img_transform,
                        target_transform)
train_loader = DataLoader(train_set,
                          shuffle=True,
                          num_workers=12,
                          batch_size=args['train_batch_size'])

# Loss functions, all moved to the GPU at import time.
criterion = nn.BCEWithLogitsLoss().cuda()
criterion_BCE = nn.BCELoss().cuda()
criterion_MAE = nn.L1Loss().cuda()
criterion_MSE = nn.MSELoss().cuda()

# The log file is named after the timestamp taken when the module is loaded.
log_path = os.path.join(ckpt_path, exp_name,
                        '{}.txt'.format(datetime.datetime.now()))


def main():
    """Create the RGBD_sal network, move it to the GPU and put it in training mode."""
    model = RGBD_sal()
    # `net` is the result of calling .cuda() and then .train() on the model.
    net = model.cuda().train()
Пример #24
0
def _split_by_fraction(dataset, first_fraction):
    """Randomly split ``dataset`` into two subsets.

    The first subset receives ``round(first_fraction * len(dataset))`` items
    and the second one the remainder, so the two lengths always add up to
    ``len(dataset)`` as ``torch.utils.data.random_split`` requires.
    """
    total = len(dataset)
    first_len = round(first_fraction * total)
    return torch.utils.data.random_split(dataset, [first_len, total - first_len])


def execute(experiment):
    """Run one training experiment and return its headline results.

    ``experiment['parameters']`` is copied; known float/int/bool entries of
    the copy are converted from their stored form (a bool is true only for
    the exact string ``'True'``).  ``experiment['raw_results']`` is handed to
    ``train_val`` and returned unchanged by this function.  The function
    seeds the RNGs, builds the train/test sets selected by the module-level
    ``args.dataset``, creates the model, optimizer and learning-rate
    scheduler named in the parameters, and delegates to ``train_val``.

    Several attributes of the module-level ``args`` namespace are written
    (``experiment_path``, ``execution_seed``, ``number_of_*_classes``,
    ``execution_path``).

    Returns:
        tuple: ``(best_train_acc1, best_val_acc1, final_train_loss,
        raw_results)``.

    Raises:
        ValueError: if ``parameters['training_method']`` or
            ``parameters['learning_method']`` is not one of the supported
            names.
    """
    print('EXECUTE')
    print(experiment['parameters'])
    time.sleep(3)
    parameters = experiment['parameters'].copy()
    raw_results = experiment['raw_results']
    args.experiment_path = os.path.join("artifacts", parameters['dataset'], parameters['architecture'])

    ######################################
    # Initial configuration...
    ######################################

    float_parameters = [
        'test_set_split',
        'reduce_train_set',
        'validation_set_split',
        'sgd_lr',
        'sgd_momentum',
        'sgd_weight_decay',
        'sgd_dampening',
        'adam_lr',
        'adam_beta1',
        'adam_beta2',
        'adam_eps',
        'adam_weight_decay',
        'rmsprop_lr',
        'rmsprop_momentum',
        'rmsprop_alpha',
        'rmsprop_eps',
        'rmsprop_weight_decay',
        'adagrad_lr',
        'adagrad_learning_decay',
        'adagrad_weight_decay',
        'adagrad_initial_acumulator',
        'tas_alpha',
        'tas_beta',
        'tas_gamma']

    int_parameters = [
        'epochs',
        'batch_size',
        'executions',
        'base_seed']

    bool_parameters = [
        'do_validation_set',
        'combine_datasets',
        'sgd_nesterov',
        'adam_amsgrad',
        'rmsprop_centered']

    # Convert only the entries that are actually present.
    for float_parameter in float_parameters:
        if float_parameter in parameters:
            parameters[float_parameter] = float(parameters[float_parameter])

    for int_parameter in int_parameters:
        if int_parameter in parameters:
            parameters[int_parameter] = int(parameters[int_parameter])

    for bool_parameter in bool_parameters:
        if bool_parameter in parameters:
            parameters[bool_parameter] = parameters[bool_parameter] == 'True'

    # Using seeds...
    random.seed(parameters['base_seed'])
    numpy.random.seed(parameters['base_seed'])
    torch.manual_seed(parameters['base_seed'])
    torch.cuda.manual_seed(parameters['base_seed'])
    args.execution_seed = parameters['base_seed'] + args.execution
    print("EXECUTION SEED:", args.execution_seed)

    print(args.dataset)
    # Configuring args and dataset...
    if args.dataset == "mnist":
        args.number_of_dataset_classes = 10
        args.number_of_model_classes = args.number_of_dataset_classes
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        train_transform = transforms.Compose(
            [transforms.ToTensor(), normalize])
        inference_transform = transforms.Compose([transforms.ToTensor(), normalize])
        dataset_path = args.dataset_dir if args.dataset_dir else "datasets/mnist"
        train_set = torchvision.datasets.MNIST(root=dataset_path, train=True, download=True, transform=train_transform)
        test_set = torchvision.datasets.MNIST(root=dataset_path, train=False, download=True, transform=inference_transform)
    elif args.dataset == "cifar10":
        args.number_of_dataset_classes = 10
        args.number_of_model_classes = args.number_of_dataset_classes
        normalize = transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))
        train_transform = transforms.Compose(
            [transforms.RandomCrop(32, padding=4),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(), normalize])
        inference_transform = transforms.Compose([transforms.ToTensor(), normalize])
        dataset_path = args.dataset_dir if args.dataset_dir else "datasets/cifar10"
        train_set = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=train_transform)
        test_set = torchvision.datasets.CIFAR10(root=dataset_path, train=False, download=True, transform=inference_transform)
    elif args.dataset == "cifar100":
        args.number_of_dataset_classes = 100
        args.number_of_model_classes = args.number_of_dataset_classes
        normalize = transforms.Normalize((0.507, 0.486, 0.440), (0.267, 0.256, 0.276))
        train_transform = transforms.Compose(
            [transforms.RandomCrop(32, padding=4),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(), normalize])
        inference_transform = transforms.Compose([transforms.ToTensor(), normalize])
        dataset_path = args.dataset_dir if args.dataset_dir else "datasets/cifar100"
        train_set = torchvision.datasets.CIFAR100(root=dataset_path, train=True, download=True, transform=train_transform)
        test_set = torchvision.datasets.CIFAR100(root=dataset_path, train=False, download=True, transform=inference_transform)
    else:
        # Any other dataset name is treated as an ImageNet-style folder tree
        # with 'train' and 'val' sub-directories.
        args.number_of_dataset_classes = 1000
        args.number_of_model_classes = args.number_of_model_classes if args.number_of_model_classes else 1000
        if args.arch.startswith('inception'):
            size = (299, 299)
        else:
            size = (224, 256)
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        train_transform = transforms.Compose(
            [transforms.RandomResizedCrop(size[0]),  # 224 , 299
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(), normalize])
        inference_transform = transforms.Compose(
            [transforms.Resize(size[1]),  # 256
             transforms.CenterCrop(size[0]),  # 224 , 299
             transforms.ToTensor(), normalize])
        dataset_path = args.dataset_dir if args.dataset_dir else "/mnt/ssd/imagenet_scripts/2012/images"
        train_path = os.path.join(dataset_path, 'train')
        val_path = os.path.join(dataset_path, 'val')
        train_set = ImageFolder(train_path, transform=train_transform)
        test_set = ImageFolder(val_path, transform=inference_transform)

    # Preparing paths...
    # TODO make execution path an input
    args.execution_path = 'results'
    os.makedirs(args.execution_path, exist_ok=True)

    ######################################
    # Preparing data...
    ######################################
    if parameters['combine_datasets']:
        # Pool train and test data, then redraw the test split at random.
        complete_dataset = torch.utils.data.ConcatDataset((train_set, test_set))
        train_set, test_set = _split_by_fraction(
            complete_dataset, 1 - parameters['test_set_split'])
    if parameters['do_validation_set']:
        # The validation subset is carved out of the *current* training set,
        # so the split sizes must be derived from len(train_set).
        # NOTE(review): validation_set is not used below; train_val receives
        # test_loader as its evaluation loader.
        train_set, validation_set = _split_by_fraction(
            train_set, 1 - parameters['validation_set_split'])

    if parameters['reduce_train_set'] != 1.0:
        # Keep only the requested fraction of both the train and test sets.
        train_set, _ = _split_by_fraction(train_set, parameters['reduce_train_set'])
        test_set, _ = _split_by_fraction(test_set, parameters['reduce_train_set'])

    # TODO make shuffle a general parameter
    train_loader = DataLoader(train_set,
                              batch_size=parameters['batch_size'],
                              num_workers=args.workers,
                              shuffle=True)

    test_loader = DataLoader(test_set,
                             batch_size=parameters['batch_size'],
                             num_workers=args.workers,
                             shuffle=True)

    print("\n$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
    print("TRAINSET LOADER SIZE: ====>>>> ", len(train_loader.sampler))
    print("TESTSET LOADER SIZE: ====>>>> ", len(test_loader.sampler))
    print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")

    # Dataset created...
    print("\nDATASET:", args.dataset)

    # create model
    # The model is initialised under the per-execution seed; afterwards the
    # generators are re-seeded with args.base_seed.
    torch.manual_seed(args.execution_seed)
    torch.cuda.manual_seed(args.execution_seed)
    print("=> creating model '{}'".format(parameters['architecture']))
    model = models.__dict__[parameters['architecture']](num_classes=args.number_of_model_classes)
    model.cuda()
    print("\nMODEL:", model)
    # NOTE(review): this reads args.base_seed while the seeding above uses
    # parameters['base_seed'] -- confirm both hold the same value.
    torch.manual_seed(args.base_seed)
    torch.cuda.manual_seed(args.base_seed)

    #########################################
    # Training...
    #########################################

    # define loss function (criterion)...
    criterion = nn.CrossEntropyLoss().cuda()

    # define optimizer..
    training_method = parameters['training_method']
    if training_method == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=parameters['sgd_lr'],
                                    momentum=parameters['sgd_momentum'],
                                    weight_decay=parameters['sgd_weight_decay'],
                                    nesterov=parameters['sgd_nesterov'])
    elif training_method == 'adam':
        print('****************AMSGRAD*************')
        print(parameters['adam_amsgrad'])
        time.sleep(3)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=parameters['adam_lr'],
                                     betas=[parameters['adam_beta1'], parameters['adam_beta2']],
                                     eps=parameters['adam_eps'],
                                     weight_decay=parameters['adam_weight_decay'],
                                     amsgrad=parameters['adam_amsgrad'])
    elif training_method == 'adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=parameters['adagrad_lr'],
                                        lr_decay=parameters['adagrad_learning_decay'],
                                        weight_decay=parameters['adagrad_weight_decay'])
    elif training_method == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=parameters['rmsprop_lr'],
                                        momentum=parameters['rmsprop_momentum'],
                                        alpha=parameters['rmsprop_alpha'],
                                        eps=parameters['rmsprop_eps'],
                                        centered=parameters['rmsprop_centered'],
                                        weight_decay=parameters['rmsprop_weight_decay'])
    else:
        # Fail here with a clear message instead of an UnboundLocalError
        # further down.
        raise ValueError(
            "unknown training_method: {!r}".format(training_method))

    # define scheduler...
    print(parameters)
    learning_method = parameters['learning_method']
    if learning_method == 'tas':
        alpha = parameters['tas_alpha']
        beta = parameters['tas_beta']
        gamma = parameters['tas_gamma']
        # Sigmoid-shaped multiplier that decays from about 1 towards gamma as
        # epoch / total epochs passes beta; alpha controls the steepness.
        our_lambda = lambda epoch: (1 - gamma)/(1 + math.exp(alpha*(epoch/parameters['epochs']-beta))) + gamma
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=our_lambda)
    elif learning_method == 'fixed_interval':
        # NOTE(review): 'fixed_interval_rate' is passed as step_size and
        # 'fixed_interval_period' as gamma, which looks swapped compared with
        # the 'fixed_epochs' branch (rate -> gamma) -- confirm before changing.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=parameters['fixed_interval_rate'],
                                                    gamma=parameters['fixed_interval_period'])
    elif learning_method == 'fixed_epochs':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=parameters['fixed_epochs_milestones'],
                                                         gamma=parameters['fixed_epochs_rate'])
    elif learning_method == 'constant':
        scheduler = None
    else:
        raise ValueError(
            "unknown learning_method: {!r}".format(learning_method))

    print("\n################ TRAINING ################")
    best_model_file_path = os.path.join(args.execution_path, 'best_model.pth.tar')
    best_train_acc1, best_val_acc1, final_train_loss = \
        train_val(parameters, raw_results, train_loader, test_loader, model, criterion, optimizer,
                  scheduler, best_model_file_path)

    return best_train_acc1, best_val_acc1, final_train_loss, raw_results
Пример #25
0
def _episode_dataset(config, mode, episode_num, transform):
    """Build the episodic ImageFolder dataset for the given split ``mode``."""
    return ImageFolder(
        data_root=config['general']['data_root'],
        mode=mode,
        episode_num=episode_num,
        way_num=config['general']['way_num'],
        shot_num=config['general']['shot_num'],
        query_num=config['general']['query_num'],
        transform=transform,
    )


def _episode_loader(config, dataset, batch_size):
    """Wrap ``dataset`` in the shuffled, drop-last DataLoader used for every split."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=config['general']['workers_num'],
        drop_last=True,
        pin_memory=True)


def _train_and_evaluate(config, save_path, fout_file):
    """Build model/optimizer/scheduler/loss and run the epoch loop, logging to ``fout_file``."""
    image_size = config['general']['image_size']

    train_trsfms = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    val_trsfms = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    model = ALTNet(**config['arch'])
    print_func(model, fout_file)

    optimizer = optim.Adam(model.parameters(), lr=config['train']['optim_lr'])

    scheduler_cfg = config['train']['lr_scheduler']
    if scheduler_cfg['name'] == 'StepLR':
        lr_scheduler = optim.lr_scheduler.StepLR(
            optimizer=optimizer, **scheduler_cfg['args'])
    elif scheduler_cfg['name'] == 'MultiStepLR':
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, **scheduler_cfg['args'])
    else:
        raise RuntimeError(
            'unsupported lr_scheduler: {!r}'.format(scheduler_cfg['name']))

    loss_cfg = config['train']['loss']
    if loss_cfg['name'] == 'CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss(**loss_cfg['args'])
    else:
        raise RuntimeError('unsupported loss: {!r}'.format(loss_cfg['name']))

    device, _ = prepare_device(config['n_gpu'])
    model = model.to(device)
    criterion = criterion.to(device)

    best_val_prec1 = 0
    best_test_prec1 = 0
    for epoch_index in range(config['train']['epochs']):
        print_func('{} Epoch {} {}'.format('=' * 35, epoch_index, '=' * 35),
                   fout_file)

        # The datasets are rebuilt at the start of every epoch.
        train_dataset = _episode_dataset(
            config, 'train', config['train']['episode_num'], train_trsfms)
        val_dataset = _episode_dataset(
            config, 'val', config['test']['episode_num'], val_trsfms)
        test_dataset = _episode_dataset(
            config, 'test', config['test']['episode_num'], val_trsfms)

        print_func(
            'The num of the train_dataset: {}'.format(len(train_dataset)),
            fout_file)
        print_func('The num of the val_dataset: {}'.format(len(val_dataset)),
                   fout_file)
        print_func('The num of the test_dataset: {}'.format(len(test_dataset)),
                   fout_file)

        train_loader = _episode_loader(
            config, train_dataset, config['train']['batch_size'])
        val_loader = _episode_loader(
            config, val_dataset, config['test']['batch_size'])
        test_loader = _episode_loader(
            config, test_dataset, config['test']['batch_size'])

        print_func('============ Train on the train set ============',
                   fout_file)
        train(train_loader, model, criterion, optimizer, epoch_index, device,
              fout_file, config['general']['image2level'],
              config['general']['print_freq'])

        print_func('============ Validation on the val set ============',
                   fout_file)
        val_prec1 = validate(val_loader, model, criterion, epoch_index, device,
                             fout_file, config['general']['image2level'],
                             config['general']['print_freq'])
        # The "best" value logged here is the best of the previous epochs;
        # it is updated only after both evaluations below.
        print_func(
            ' * Prec@1 {:.3f} Best Prec1 {:.3f}'.format(
                val_prec1, best_val_prec1), fout_file)

        print_func('============ Testing on the test set ============',
                   fout_file)
        test_prec1 = validate(test_loader, model, criterion, epoch_index,
                              device, fout_file,
                              config['general']['image2level'],
                              config['general']['print_freq'])
        print_func(
            ' * Prec@1 {:.3f} Best Prec1 {:.3f}'.format(
                test_prec1, best_test_prec1), fout_file)

        # Model selection is driven by validation precision; the test
        # precision of that same epoch is recorded alongside it.
        if val_prec1 > best_val_prec1:
            best_val_prec1 = val_prec1
            best_test_prec1 = test_prec1
            save_model(model,
                       save_path,
                       config['data_name'],
                       epoch_index,
                       is_best=True)

        # Periodic snapshot every save_freq epochs, skipping epoch 0.
        if epoch_index % config['general'][
                'save_freq'] == 0 and epoch_index != 0:
            save_model(model,
                       save_path,
                       config['data_name'],
                       epoch_index,
                       is_best=False)

        lr_scheduler.step()

    print_func('............Training is end............', fout_file)


def main(config):
    """Train an ALTNet model as described by ``config`` and log the run.

    A result directory named
    ``<data_name>_<base_model>_<way_num>way_<shot_num>shot`` is created under
    ``config['general']['save_root']`` (including missing parents).  The
    configuration is written there as ``config.json`` and progress is
    appended to ``train_info.txt``, which is closed when training finishes
    or fails.

    Args:
        config: nested dict with at least the ``data_name``, ``n_gpu``,
            ``arch``, ``general``, ``train`` and ``test`` entries read by
            this function and its helpers.

    Raises:
        RuntimeError: if the configured LR scheduler or loss name is not
            supported.
    """
    result_name = '{}_{}_{}way_{}shot'.format(
        config['data_name'],
        config['arch']['base_model'],
        config['general']['way_num'],
        config['general']['shot_num'],
    )
    save_path = os.path.join(config['general']['save_root'], result_name)
    # makedirs also creates a missing save_root and tolerates an existing dir.
    os.makedirs(save_path, exist_ok=True)

    fout_path = os.path.join(save_path, 'train_info.txt')
    with open(fout_path, 'a+') as fout_file:
        with open(os.path.join(save_path, 'config.json'), 'w') as handle:
            json.dump(config, handle, indent=4, sort_keys=True)
        print_func(config, fout_file)
        _train_and_evaluate(config, save_path, fout_file)
Пример #26
0
 # Image augmentation: rescale to int(imsize * 76 / 64), random-crop back to
 # imsize, then flip horizontally at random.
 # NOTE(review): transforms.Scale is the older torchvision name for
 # transforms.Resize -- confirm it exists in the pinned torchvision version.
 image_transform = transforms.Compose([
     transforms.Scale(int(imsize * 76 / 64)),
     transforms.RandomCrop(imsize),
     transforms.RandomHorizontalFlip()
 ])
 # The dataset class is chosen from substrings of cfg.DATA_DIR first, and
 # only then from the conditional (text-to-image) setting.  The dataset
 # modules are imported lazily inside the branch that needs them.
 if cfg.DATA_DIR.find('lsun') != -1:
     from datasets import LSUNClass
     dataset = LSUNClass('%s/%s_%s_lmdb' %
                         (cfg.DATA_DIR, cfg.DATASET_NAME, split_dir),
                         base_size=cfg.TREE.BASE_SIZE,
                         transform=image_transform)
 elif cfg.DATA_DIR.find('imagenet') != -1:
     from datasets import ImageFolder
     # Always uses the 'train' split, restricted to the classes listed in
     # CLASS_DIC for the configured dataset name.
     dataset = ImageFolder(cfg.DATA_DIR,
                           split_dir='train',
                           custom_classes=CLASS_DIC[cfg.DATASET_NAME],
                           base_size=cfg.TREE.BASE_SIZE,
                           transform=image_transform)
 elif cfg.GAN.B_CONDITION:  # text to image task
     # In the branches visible here, `dataset` is only assigned for the
     # 'birds' and 'flowers' dataset names.
     if cfg.DATASET_NAME == 'birds':
         from datasets import TextDataset
         dataset = TextDataset(cfg.DATA_DIR,
                               split_dir,
                               base_size=cfg.TREE.BASE_SIZE,
                               transform=image_transform)
     elif cfg.DATASET_NAME == 'flowers':
         from datasets import FlowersDataset
         dataset = FlowersDataset(cfg.DATA_DIR,
                                  split_dir,
                                  base_size=cfg.TREE.BASE_SIZE,
                                  transform=image_transform)
Пример #27
0
def main(result_path, epoch_num):
    """Evaluate the best saved ALTNet model over ``epoch_num`` test rounds.

    The configuration is read from ``<result_path>/config.json`` and the
    weights from ``<result_path>/<data_name>_best_model.pth``.  Each round
    builds a fresh test dataset of 600 episodes, runs ``validate`` on it and
    logs the mean accuracy and confidence half-width returned by
    ``mean_confidence_interval``.  At the end the mean accuracy over all
    collected accuracies and the mean half-width over rounds are logged.
    Output is appended to ``<result_path>/test_info.txt``.

    Args:
        result_path: directory holding ``config.json`` and the saved model.
        epoch_num: number of evaluation rounds to run.

    Raises:
        RuntimeError: if the configured loss name is not supported.
    """
    with open(os.path.join(result_path, 'config.json')) as handle:
        config = json.load(handle)

    fout_path = os.path.join(result_path, 'test_info.txt')
    # The context manager closes the log file even if evaluation fails.
    with open(fout_path, 'a+') as fout_file:
        print_func(config, fout_file)

        trsfms = transforms.Compose([
            transforms.Resize((config['general']['image_size'],
                               config['general']['image_size'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        model = ALTNet(**config['arch'])
        print_func(model, fout_file)

        # Load onto the CPU first so that a checkpoint saved from a GPU can
        # also be restored on a CPU-only machine; the model is moved to the
        # selected device below.
        state_dict = torch.load(
            os.path.join(result_path,
                         '{}_best_model.pth'.format(config['data_name'])),
            map_location='cpu')
        model.load_state_dict(state_dict)

        if config['train']['loss']['name'] == 'CrossEntropyLoss':
            criterion = nn.CrossEntropyLoss(**config['train']['loss']['args'])
        else:
            raise RuntimeError('unsupported loss: {!r}'.format(
                config['train']['loss']['name']))

        device, _ = prepare_device(config['n_gpu'])
        model = model.to(device)
        criterion = criterion.to(device)

        total_h = np.zeros(epoch_num)
        total_accuracy_vector = []
        for epoch_idx in range(epoch_num):
            test_dataset = ImageFolder(
                data_root=config['general']['data_root'],
                mode='test',
                episode_num=600,
                way_num=config['general']['way_num'],
                shot_num=config['general']['shot_num'],
                query_num=config['general']['query_num'],
                transform=trsfms,
            )

            print_func(
                'The num of the test_dataset: {}'.format(len(test_dataset)),
                fout_file)

            test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=config['test']['batch_size'],
                shuffle=True,
                num_workers=config['general']['workers_num'],
                drop_last=True,
                pin_memory=True)

            print_func('============ Testing on the test set ============',
                       fout_file)
            _, accuracies = validate(test_loader, model, criterion, epoch_idx,
                                     device, fout_file,
                                     config['general']['image2level'],
                                     config['general']['print_freq'])
            test_accuracy, h = mean_confidence_interval(accuracies)
            print_func("Test Accuracy: {}\t h: {}".format(test_accuracy, h[0]),
                       fout_file)

            total_accuracy_vector.extend(accuracies)
            total_h[epoch_idx] = h

        aver_accuracy, _ = mean_confidence_interval(total_accuracy_vector)
        print_func(
            'Aver Accuracy: {:.3f}\t Aver h: {:.3f}'.format(
                aver_accuracy, total_h.mean()), fout_file)
        print_func('............Testing is end............', fout_file)
Пример #28
0
	def __init__(self, options):
		"""Set up data, networks, optimizers and fixed visualisation inputs for stage 2.

		Loads the image dataset together with the pre-computed codes and
		statistics stored under ``options.enc_path``, creates the generator,
		classifier, discriminator and style bank (each with its own optimizer,
		registered through ``add_model``), and prepares the fixed
		reconstruction / sample inputs used by ``visualize_fixed``.

		Attributes read here but not assigned here (``self.device``,
		``self.lr``, ``self.sty_lr``, ``self.nclass``, ``self.con_weight``,
		``self.load_path``, ``self.save_path``, ``self.vis_col``,
		``self.state``) are presumably provided by the base-class initialiser
		via ``copy_keys`` -- TODO confirm.
		"""
		super(Stage2Trainer, self).__init__(options, subfolders = ['samples', 'reconstructions'], copy_keys = copy_keys)

		# Image preprocessing: optional centre crop, resize, tensor conversion and
		# normalisation. The mean/std tuples have four entries; the fourth
		# (mean 0, std 1) leaves that channel unchanged.
		# NOTE(review): this assumes 4-channel image tensors -- confirm against ImageFolder.
		transforms = []
		if options.crop_size is not None:
			transforms.append(T.CenterCrop(options.crop_size))
		transforms.append(T.Resize(options.image_size))
		transforms.append(T.ToTensor())
		transforms.append(T.Normalize((0.5, 0.5, 0.5, 0), (0.5, 0.5, 0.5, 1)))

		# Pair every image with the codes saved as '<enc_iter>_codes.pt'; the two
		# slices along dim 1 of the loaded tensor become the two tensors of code_set.
		image_set = ImageFolder(options.data_root, transform = T.Compose(transforms))
		enc_codes = torch.load(os.path.join(options.enc_path, 'codes', '{0}_codes.pt'.format(options.enc_iter)))
		code_set = torch.utils.data.TensorDataset(enc_codes[:, 0], enc_codes[:, 1])
		self.dataset = ParallelDataset(image_set, code_set)
		self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size = options.batch_size, shuffle = True, drop_last = True, num_workers = options.nloader)
		self.data_iter = iter(self.dataloader)

		# Statistics saved alongside the codes as '<enc_iter>_stats.pt'.
		enc_stats = torch.load(os.path.join(options.enc_path, 'codes', '{0}_stats.pt'.format(options.enc_iter)))
		self.con_full_mean = enc_stats['full_mean']
		self.con_full_std = enc_stats['full_std']
		self.con_eigval = enc_stats['eigval']
		self.con_eigvec = enc_stats['eigvec']
		self.dim_weight = enc_stats['dim_weight']

		# The encoder is only built (with its saved weights) when con_weight is positive.
		if self.con_weight > 0:
			self.enc = models.Encoder(options.image_size, options.image_size, options.enc_features, options.enc_blocks, options.enc_adain_features, options.enc_adain_blocks, options.content_size)
			self.enc.to(self.device)
			self.enc.load_state_dict(torch.load(os.path.join(options.enc_path, 'models', '{0}_enc.pt'.format(options.enc_iter)), map_location = self.device))

		# Generator: starts from the weights saved under enc_path unless a
		# checkpoint is being resumed (load_path set) or reset_gen is requested.
		self.gen = models.TwoPartNestedDropoutGenerator(options.image_size, options.image_size, options.gen_features, options.gen_blocks, options.gen_adain_features, options.gen_adain_blocks, options.content_size, options.style_size)
		self.gen.to(self.device)
		if (self.load_path is None) and not options.reset_gen:
			self.gen.load_state_dict(torch.load(os.path.join(options.enc_path, 'models', '{0}_gen.pt'.format(options.enc_iter)), map_location = self.device))
		self.gen_optim = optim.RMSprop(self.gen.parameters(), lr = self.lr, eps = 1e-4)
		self.add_model('gen', self.gen, self.gen_optim)

		# Classifier: optionally initialised from the weights under cla_path.
		self.cla = models.ClassifierOrDiscriminator(options.image_size, options.image_size, options.cla_features, options.cla_blocks, options.cla_adain_features, options.cla_adain_blocks, self.nclass)
		self.cla.to(self.device)
		if (self.load_path is None) and (options.cla_path is not None) and not options.reset_cla:
			self.cla.load_state_dict(torch.load(os.path.join(options.cla_path, 'models', '{0}_cla.pt'.format(options.cla_iter)), map_location = self.device))
		self.cla_optim = optim.RMSprop(self.cla.parameters(), lr = self.lr, eps = 1e-4)
		self.add_model('cla', self.cla, self.cla_optim)

		# Discriminator: same class as the classifier and initialised from the same
		# '<cla_iter>_cla.pt' file, then converted with convert().
		self.dis = models.ClassifierOrDiscriminator(options.image_size, options.image_size, options.dis_features, options.dis_blocks, options.dis_adain_features, options.dis_adain_blocks, self.nclass)
		self.dis.to(self.device)
		if (self.load_path is None) and (options.cla_path is not None) and not options.reset_dis:
			self.dis.load_state_dict(torch.load(os.path.join(options.cla_path, 'models', '{0}_cla.pt'.format(options.cla_iter)), map_location = self.device))
		self.dis.convert()
		self.dis_optim = optim.RMSprop(self.dis.parameters(), lr = self.lr, eps = 1e-4)
		self.add_model('dis', self.dis, self.dis_optim)

		# Style bank: built from the class frequencies of the image set and trained
		# with Adam at its own learning rate (sty_lr).
		self.sty = models.NormalizedStyleBank(self.nclass, options.style_size, image_set.get_class_freq())
		self.sty.to(self.device)
		if (self.load_path is None) and not options.reset_sty:
			self.sty.load_state_dict(torch.load(os.path.join(options.enc_path, 'models', '{0}_sty.pt'.format(options.enc_iter)), map_location = self.device))
		self.sty_optim = optim.Adam(self.sty.parameters(), lr = self.sty_lr, eps = 1e-8)
		self.add_model('sty', self.sty, self.sty_optim)

		# Fixed visualisation inputs: reuse the ones stored with the checkpoint when
		# resuming, otherwise draw a random set of dataset items and noise.
		if self.load_path is not None:
			rec_images = torch.load(os.path.join(self.load_path, 'reconstructions', 'images.pt'), map_location = self.device)
			self.rec_codes = torch.load(os.path.join(self.load_path, 'reconstructions', 'codes.pt'), map_location = self.device)
			self.rec_labels = torch.load(os.path.join(self.load_path, 'reconstructions', 'labels.pt'), map_location = self.device)
			self.vis_codes = torch.load(os.path.join(self.load_path, 'samples', 'codes.pt'), map_location = self.device)
			self.vis_style_noise = torch.load(os.path.join(self.load_path, 'samples', 'style_noise.pt'), map_location = self.device)
			self.load(options.load_iter)
		else:
			rec_images = []
			rec_codes = []
			rec_labels = []
			# One dataset item per cell of the vis_row x vis_col grid.
			rec_index = random.sample(range(len(self.dataset)), options.vis_row * options.vis_col)
			for k in rec_index:
				# Each dataset item unpacks into four values; the last one is unused here.
				image, label, code, _ = self.dataset[k]
				rec_images.append(image)
				rec_codes.append(code)
				rec_labels.append(label)
			rec_images = torch.stack(rec_images, dim = 0)
			self.rec_codes = torch.stack(rec_codes, dim = 0).to(self.device)
			self.rec_labels = one_hot(torch.tensor(rec_labels, dtype = torch.int32), self.nclass).to(self.device)
			self.vis_codes = self.noise_to_con_code(gaussian_noise(options.vis_row * options.vis_col, options.content_size)).to(self.device)
			self.vis_style_noise = gaussian_noise(options.vis_row * options.vis_col, options.style_size).to(self.device)

			self.state.dis_total_batches = 0

		# Store the fixed inputs next to the new run unless it writes to the same
		# directory it was loaded from.
		if self.save_path != self.load_path:
			torch.save(rec_images, os.path.join(self.save_path, 'reconstructions', 'images.pt'))
			torch.save(self.rec_codes, os.path.join(self.save_path, 'reconstructions', 'codes.pt'))
			torch.save(self.rec_labels, os.path.join(self.save_path, 'reconstructions', 'labels.pt'))
			torch.save(self.vis_codes, os.path.join(self.save_path, 'samples', 'codes.pt'))
			torch.save(self.vis_style_noise, os.path.join(self.save_path, 'samples', 'style_noise.pt'))
			# add(1).div(2) maps the images from [-1, 1] to [0, 1] before saving
			# (assuming the Normalize step above produced values in [-1, 1]).
			save_image(rec_images.add(1).div(2), os.path.join(self.save_path, 'reconstructions', 'target.png'), self.vis_col)

		# Register the periodic visualisation and run it once immediately.
		self.add_periodic_func(self.visualize_fixed, options.visualize_iter)
		self.visualize_fixed()

		self.loss_avg_factor = 0.9

		# Entry i holds style_dropout ** i; every entry of the tensor is
		# overwritten by the loop below.
		self.sty_drop_prob = torch.Tensor(options.style_size)
		for i in range(options.style_size):
			self.sty_drop_prob[i] = options.style_dropout ** i
Пример #29
0
# Queue pre-filled with three independent frames of fourteen 0xff values.
detect_queue = deque([0xff] * 14 for _ in range(3))

# Make sure the output directory exists, then build the DataLoader over the
# ImageFolder dataset (no shuffling, 8 worker processes).
os.makedirs(save_path, exist_ok=True)
dataloader = DataLoader(ImageFolder(image_folder, img_size=img_size),
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=8)

# Extract the class labels from the 'tl.names' file.
classes = load_classes('tl.names')

# Initiate model: the whole pickled model object is loaded directly onto
# 'cuda:0' (the state-dict based alternative is kept commented out below).
# NOTE(review): map_location='cuda:0' needs a CUDA device, yet CUDA
# availability is only checked after this point -- confirm this is intended.
#model = myNet(num_classes, anchors, input_size, True, is_train = True).cuda()
#model.load_state_dict(torch.load('./weight/final_param.pkl'),False)
model = torch.load('./model_0914_17.pkl', map_location='cuda:0')
model.set_test()

cuda = torch.cuda.is_available()
if cuda: