Example #1
    def __init__(self, args,
                 data_id, model_id, optim_id,
                 train_loader, eval_loader,
                 model, optimizer, scheduler_iter, scheduler_epoch):
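        """Set up a FlowExperiment: fill in default args, move the model to the
        configured device (optionally wrapping it in DataParallelDistribution),
        initialize the parent experiment, and configure TensorBoard / wandb
        logging."""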

        # Edit args
        if args.eval_every is None:
            args.eval_every = args.epochs
        if args.check_every is None:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Move model
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        log_path = os.path.join(self.log_base, data_id, model_id, optim_id, args.name)
        super(FlowExperiment, self).__init__(model=model,
                                             optimizer=optimizer,
                                             scheduler_iter=scheduler_iter,
                                             scheduler_epoch=scheduler_epoch,
                                             log_path=log_path,
                                             eval_every=args.eval_every,
                                             check_every=args.check_every)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0)
        if args.log_wandb:
            wandb.init(config=args_dict, project=args.project, id=args.name, dir=self.log_path)
Example #2
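# Reload the training-time args that were saved as a pickle.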
with open(path_args, 'rb') as f:
    args = pickle.load(f)

##################
## Specify data ##
##################

_, _, data_shape = get_data(args)

###################
## Specify model ##
###################

model = get_model(args, data_shape=data_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
checkpoint = torch.load(path_check)
model.load_state_dict(checkpoint['model'])
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############

path_samples = '{}/samples/sample_ep{}_s{}.png'.format(
    eval_args.model, checkpoint['current_epoch'], eval_args.seed)
if not os.path.exists(os.path.dirname(path_samples)):
    os.mkdir(os.path.dirname(path_samples))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
Example #3
    def __init__(self, args, data_id, model_id, optim_id, train_loader,
                 eval_loader, model, optimizer, scheduler_iter,
                 scheduler_epoch):
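        """Set up a FlowExperiment: fill in default args, build the log path (a
        timestamped debug path when args.name == "debug"), move the model to the
        configured device, initialize the parent experiment, configure
        TensorBoard / wandb logging, and set up gradient clipping and optional
        automatic mixed precision."""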

        # Edit args
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.check_every is None or args.check_every == 0:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", data_id, model_id,
                                    optim_id, f"seed{args.seed}",
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, data_id, model_id, optim_id,
                                    f"seed{args.seed}", args.name)

        # Move model
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(FlowExperiment, self).__init__(model=model,
                                             optimizer=optimizer,
                                             scheduler_iter=scheduler_iter,
                                             scheduler_epoch=scheduler_epoch,
                                             log_path=log_path,
                                             eval_every=args.eval_every,
                                             check_every=args.check_every,
                                             save_samples=args.save_samples)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)

        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # automatic mixed precision
        # (making this work with DataParallel would require bigger changes,
        #  e.g. an @autocast() decoration on each forward)
        torch_version = tuple(int(v) for v in str(torch.__version__).split('.')[:2])
        self.amp = args.amp and args.parallel != 'dp' and torch_version >= (1, 7)
        if self.amp:
            # only available in pytorch 1.7.0+
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()
Example #4
##################
## Specify data ##
##################

eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

###################
## Specify model ##
###################

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
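# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.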
checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model = model.eval()
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############


def save_images(imgs, file_path, num_bits=args.num_bits, nrow=eval_args.nrow):
    if not os.path.exists(os.path.dirname(file_path)):
        os.mkdir(os.path.dirname(file_path))
Example #5
##################
## Specify data ##
##################

eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

###################
## Specify model ##
###################

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)

checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model = model.eval()
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############

base_dir = os.path.join(eval_args.model, "likelihoods")
if not os.path.exists(base_dir):
    os.mkdir(base_dir)
Example #6
##################
## Specify data ##
##################

eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

###################
## Specify model ##
###################

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)

checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model = model.eval()
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

##############
## Evaluate ##
##############

pq = PerceptualQuality(device=device,
                       num_bits=args.num_bits,
                       sr_scale_factor=args.sr_scale_factor)
Example #7
    def __init__(self, args, data_id, model_id, optim_id, model, teacher,
                 optimizer, scheduler_iter, scheduler_epoch):
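        """Set up a StudentExperiment: fill in default args, build a log path from
        the data / conditioning / architecture identifiers, move the student model
        to the configured device, store the teacher in eval mode, initialize the
        parent experiment, configure TensorBoard / wandb logging, and set up
        gradient clipping and optional automatic mixed precision."""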

        # Edit args
        if args.eval_every is None or args.eval_every == 0:
            args.eval_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        aug_or_abs = 'abs' if args.augment_size == 0 else f"aug{args.augment_size}"
        hidden = '_'.join([str(u) for u in args.hidden_units])
        cond_id = args.cond_trans.lower()
        arch_id = f"flow_{aug_or_abs}_k{args.num_flows}_h{hidden}_{'affine' if args.affine else 'additive'}{'_actnorm' if args.actnorm else ''}"
        seed_id = f"seed{args.seed}"
        if args.name == "debug":
            log_path = os.path.join(self.log_base, "debug", model_id, data_id,
                                    cond_id, arch_id, seed_id,
                                    time.strftime("%Y-%m-%d_%H-%M-%S"))
        else:
            log_path = os.path.join(self.log_base, model_id, data_id, cond_id,
                                    arch_id, seed_id, args.name)

        # Move models
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        super(StudentExperiment,
              self).__init__(model=model,
                             optimizer=optimizer,
                             scheduler_iter=scheduler_iter,
                             scheduler_epoch=scheduler_epoch,
                             log_path=log_path,
                             eval_every=args.eval_every)
        # student teacher args
        teacher = teacher.to(args.device)
        self.teacher = teacher
        self.teacher.eval()
        self.cond_size = 1 if cond_id.startswith(('split', 'multiply')) else 2

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.model_id = model_id
        self.data_id = data_id
        self.cond_id = cond_id
        self.optim_id = optim_id
        self.arch_id = arch_id
        self.seed_id = seed_id

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)

        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)

        # training params
        self.max_grad_norm = args.max_grad_norm

        # automatic mixed precision
        # (making this work with DataParallel would require bigger changes,
        #  e.g. an @autocast() decoration on each forward)
        torch_version = tuple(int(v) for v in str(torch.__version__).split('.')[:2])
        self.amp = args.amp and args.parallel != 'dp' and torch_version >= (1, 7)
        if self.amp:
            # only available in pytorch 1.7.0+
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # save model architecture for reference
        self.save_architecture()
Example #8
##################
## Specify data ##
##################

eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

####################
## Specify models ##
####################

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# conditional model
model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model = model.eval()
print('Loaded weights for conditional model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

# prior model
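# The prior is built for the downscaled input shape (spatial dims divided by sr_scale_factor).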
prior_model = get_model(prior_args,
                        data_shape=(data_shape[0],
                                    data_shape[1] // args.sr_scale_factor,
                                    data_shape[2] // args.sr_scale_factor))
if prior_args.parallel == 'dp':
    prior_model = DataParallelDistribution(prior_model)
prior_checkpoint = torch.load(path_prior_check,
                              map_location=torch.device(device))
Example #9
with open(path_args, 'rb') as f:
    args = pickle.load(f)

##################
## Specify data ##
##################

eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

# Adjust args
args.batch_size = eval_args.batch_size

###################
## Specify model ##
###################

model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
checkpoint = torch.load(path_check)
model.load_state_dict(checkpoint['model'])
print('Loaded weights for model at {}/{} epochs'.format(checkpoint['current_epoch'], args.epochs))

# Load checkpoint
exp.checkpoint_load('{}/check/'.format(more_args.model), device=more_args.new_device)

# modify model
if more_args.new_device is not None:
    exp.model.to(torch.device(more_args.new_device))

exp.eval_fn()
Example #10
args.device = torch.device(device)

##################
## Specify data ##
##################

args.batch_size = 1
eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

###################
## Specify model ##
###################

model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)

checkpoint = torch.load(path_check, map_location=device)
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model = model.eval()
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############


def save_images(imgs, file_path, num_bits=args.num_bits, nrow=1):
    if not os.path.exists(os.path.dirname(file_path)):
        os.mkdir(os.path.dirname(file_path))