# __init__ of a trainer class (method excerpt); relies on os, torch,
# torch.utils.data.{Dataset, DataLoader, WeightedRandomSampler}, and VAELoss.
def __init__(self, args, model: torch.nn.Module, train_dataset: Dataset,
                 test_dataset: Dataset, utils):
        self.utils = utils
        self.args = args
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.batch_size = self.args.batch_size
        self.img_size = self.args.img_size

        self.model = model.to(self.device)

        os.makedirs(os.path.join(self.args.ckpt_dir, self.model.name),
                    exist_ok=True)
        os.makedirs(self.args.save_gen_images_dir, exist_ok=True)
        # optimizer

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.lr)
        # dataset and dataloader
        self.train_dataset = train_dataset
        weights = self.utils.make_weights_for_balanced_classes(
            self.train_dataset.imgs, len(self.train_dataset.classes))
        weights = torch.DoubleTensor(weights)
        sampler = WeightedRandomSampler(weights, len(weights))

        self.train_dataloader = DataLoader(self.train_dataset,
                                           self.batch_size,
                                           num_workers=args.num_worker,
                                           sampler=sampler,
                                           pin_memory=True)

        self.test_dataset = test_dataset
        self.test_dataloader = DataLoader(self.test_dataset,
                                          self.batch_size,
                                          num_workers=args.num_worker,
                                          pin_memory=True)
        # loss function
        self.criterion = VAELoss().to(self.device)
        # scheduler
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=3)
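For reference, utils.make_weights_for_balanced_classes is defined elsewhere in the project; below is a minimal sketch of the widely shared implementation it most likely follows, assuming train_dataset.imgs is the list of (path, class_index) tuples exposed by torchvision.datasets.ImageFolder:

def make_weights_for_balanced_classes(images, nclasses):
    # count how many samples fall into each class
    count = [0] * nclasses
    for _, class_idx in images:
        count[class_idx] += 1
    # a class's weight is inversely proportional to its frequency
    total = float(sum(count))
    weight_per_class = [total / c for c in count]
    # every sample inherits the weight of its class
    return [weight_per_class[class_idx] for _, class_idx in images]

Note also that ReduceLROnPlateau, unlike most schedulers, must be stepped with the monitored metric, e.g. self.scheduler.step(val_loss) after each validation pass.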
Example #2
File: train.py  Project: yz27/VAE_ISIC2018
# load checkpoint
if args.resume is not None:
    # map_location keeps tensors on the CPU regardless of the device they were saved from
    checkpoint = torch.load(args.resume,
                            map_location=lambda storage, loc: storage)
    print("checkpoint loaded!")
    print("val loss: {}\tepoch: {}\t".format(checkpoint['val_loss'],
                                             checkpoint['epoch']))

# model
model = VAE(args.image_size)
if args.resume is not None:
    model.load_state_dict(checkpoint['state_dict'])

# criterion
criterion = VAELoss(size_average=True, kl_weight=args.kl_weight)
if args.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# load data
train_loader, val_loader = load_vae_train_datasets(input_size=args.image_size,
                                                   data=args.data,
                                                   batch_size=args.batch_size)

# load optimizer and scheduler
opt = torch.optim.Adam(params=model.parameters(),
                       lr=args.lr,
                       betas=(0.9, 0.999))
if args.resume is not None and not args.reset_opt:
    opt.load_state_dict(checkpoint['optimizer'])
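The resume branch above reads the keys val_loss, epoch, state_dict, and optimizer out of the checkpoint; the saving side is not part of this snippet, but a minimal sketch of a compatible writer (the helper name save_checkpoint is hypothetical) would be:

def save_checkpoint(model, opt, epoch, val_loss, path):
    # bundle exactly the keys the resume branch above reads back
    torch.save({
        'epoch': epoch,
        'val_loss': val_loss,
        'state_dict': model.state_dict(),
        'optimizer': opt.state_dict(),
    }, path)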
Example #3
# (snippet truncated above; this flag is reconstructed, its name inferred from
# the args.kl_weight usage below)
parser.add_argument('--kl_weight', type=float,
                    help="weight on KL term")
parser.add_argument('--out_csv', default='result.csv')
args = parser.parse_args()

# load checkpoint
if not os.path.isfile(args.model_path):
    print('%s is not a path to a file' % args.model_path)
    exit()
checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
print("checkpoint loaded!")
print("val loss: {}\tepoch: {}\t".format(checkpoint['val_loss'], checkpoint['epoch']))

# model and criterion
model = VAE(args.image_size)
model.load_state_dict(checkpoint['state_dict'])
criterion = VAELoss(size_average=True, kl_weight=args.kl_weight)

if args.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# load data
test_loader = load_vae_test_datasets(args.image_size, args.data)

############################# ANOMALY SCORE DEF ##########################
def get_vae_score(vae, image, L=5):
    """
    The VAE score for a single image, which is essentially its loss.
    :param image: tensor of shape [1, 3, 256, 256]
    :param L: number of latent samples used to estimate the score
    :return: (vae_loss, KL, reconst_err)
    """