Code example #1
    def create(self):
        '''Build the Optimizer object from the properties

        Return
        ------
        tf.train.Optimizer
            Ready-made optimizer
        '''
        kind = self['kind']
        learning_rate = self.learning_rate
        name = self.get('name', 'optimizer')
        optimizer_cls = get_optimizer(kind)
        if kind in ['Momentum', 'RMSProp']:
            # only those two use momentum param
            try:
                momentum = self['momentum']
            except KeyError:
                raise ValueError(
                    'momentum parameter is required for the {} optimizer'.format(kind))
            if kind == 'Momentum':
                use_nesterov = self.get('use_nesterov', False)
                # pass use_nesterov by keyword so it cannot be mistaken for the
                # positional use_locking argument of the Momentum optimizer
                return optimizer_cls(learning_rate,
                                     momentum,
                                     use_nesterov=use_nesterov,
                                     name=name)
            else:
                return optimizer_cls(learning_rate, momentum, name=name)
        else:
            return optimizer_cls(learning_rate, name=name)
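
The factory above resolves the optimizer class via get_optimizer(kind), which is not shown in this example. Below is a minimal sketch of such a lookup for the TF 1.x names used in the snippet ('Momentum', 'RMSProp', ...); it is an assumption about that helper, not the project's implementation, and get_optimizer_sketch is an illustrative name.

import tensorflow as tf  # TF 1.x API, matching the tf.train.Optimizer classes above

# Illustrative mapping; the real get_optimizer may support more names.
_TF_OPTIMIZERS = {
    'GradientDescent': tf.train.GradientDescentOptimizer,
    'Momentum': tf.train.MomentumOptimizer,
    'RMSProp': tf.train.RMSPropOptimizer,
    'Adam': tf.train.AdamOptimizer,
    'Adagrad': tf.train.AdagradOptimizer,
}

def get_optimizer_sketch(kind):
    """Map an optimizer name onto its tf.train.Optimizer class."""
    try:
        return _TF_OPTIMIZERS[kind]
    except KeyError:
        raise ValueError('Unknown optimizer kind: {}'.format(kind))
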
def train(optims, max_epoch, policy, bsize, env, num_clicks, recom_number, max_length, origin_reward, capacity):
    outputdir="model_output"
    policy_new=os.path.join(outputdir, 'model_free_simple.pickle') 
    
    #weight = torch.FloatTensor(numlabel).fill_(1)
    optim_fn, optim_params=get_optimizer(optims)
    optimizer = optim_fn(filter(lambda p: p.requires_grad, policy.parameters()), **optim_params)
    
    n_epochs=max_epoch
    max_reward = 0
    epoch = 1
    best_model = None
    rewards = [origin_reward]
    while epoch <= n_epochs:
        _ = train_gen_pg_each(policy, env, epoch, optimizer, num_clicks, recom_number, max_length, bsize, total_size=capacity)
        print('saving policy at epoch {0}'.format(epoch))
        if not os.path.exists(outputdir):
            os.makedirs(outputdir)
        torch.save(policy, policy_new)
        #Eval the new policy
        _, mean_reward = Eval(policy_new)
        rewards.append(mean_reward)
        # save model        
        if mean_reward >= max_reward:
            best_model = policy
            max_reward = mean_reward
        epoch += 1
    return best_model, rewards, max_reward
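
Several of the PyTorch examples in this listing call get_optimizer(optims) and expect back an (optim_fn, optim_params) pair parsed from a spec string such as 'adam,lr=0.001'. A minimal sketch of a helper following that convention is shown below; it is reconstructed from the call sites as an assumption, not the project's actual implementation, and get_optimizer_sketch is an illustrative name.

import inspect
import torch.optim as optim

def get_optimizer_sketch(spec):
    """Parse e.g. 'adam,lr=0.001' into (optimizer class, keyword arguments)."""
    parts = spec.split(',')
    name, raw_kwargs = parts[0].lower(), parts[1:]
    optim_params = {}
    for item in raw_kwargs:
        key, value = item.split('=')
        optim_params[key] = float(value)
    known = {'sgd': optim.SGD, 'adam': optim.Adam, 'rmsprop': optim.RMSprop,
             'adagrad': optim.Adagrad, 'adadelta': optim.Adadelta}
    optim_fn = known[name]
    # Reject keyword arguments the chosen optimizer does not accept.
    expected = inspect.signature(optim_fn.__init__).parameters
    unknown = set(optim_params) - set(expected)
    if unknown:
        raise ValueError('unexpected optimizer arguments: {}'.format(sorted(unknown)))
    return optim_fn, optim_params
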
Code example #3
	def setup_optimizers(self, optimizer_name, lr, momentum=0.9, weight_decay=0, gradient_clipping=0):
		opt = util.get_optimizer(optimizer_name, lr, momentum)
		opt.setup(self)
		if weight_decay > 0:
			opt.add_hook(chainer.optimizer.WeightDecay(weight_decay))
		if gradient_clipping > 0:
			opt.add_hook(hooks.GradientClipping(gradient_clipping))
		self._optimizer = opt
def train_dis(optims, model, bsize, embed_dim, recom_length, trainSample,
              validSample, testSample):
    outputdir = "model_output"
    outputmodelname = "simu.model.pth"
    lrshrink = 5
    minlr = 1e-5

    #weight = torch.FloatTensor(numlabel).fill_(1)
    loss_fn = nn.NLLLoss()
    loss_fn.size_average = True
    loss_fn.to(device)
    optim_fn, optim_params = get_optimizer(optims)
    optimizer = optim_fn(filter(lambda p: p.requires_grad, model.parameters()),
                         **optim_params)

    n_epochs = 5
    inner_val_acc_best = -1e10
    inner_val_map_best = -1e10
    stop_training = False
    epoch = 1
    eval_type = 'valid'
    best_model = model
    while not stop_training and epoch <= n_epochs:
        train_acc, train_map = train_dis_each(model, epoch, trainSample,
                                              optimizer, bsize, embed_dim,
                                              recom_length, loss_fn, device)
        # Evaluate no eos
        eval_acc, eval_map = evaluate_discriminator(model, epoch, bsize,
                                                    recom_length, validSample,
                                                    testSample, device,
                                                    eval_type)  # save model
        if eval_type == 'valid' and epoch <= n_epochs:
            if eval_acc > inner_val_acc_best or eval_map > inner_val_map_best:
                best_model = model
                print('saving model at epoch {0}'.format(epoch))
                if not os.path.exists(outputdir):
                    os.makedirs(outputdir)
                torch.save(
                    model.state_dict(),
                    os.path.join(outputdir, 'irecGan_dis.' + outputmodelname))
                inner_val_acc_best = eval_acc
                inner_val_map_best = eval_map
                times_no_improvement = 0
            else:
                times_no_improvement += 1
                stop_training = adj_optim(optims, optimizer, minlr, lrshrink,
                                          stop_training, times_no_improvement)
        epoch += 1
    return best_model, inner_val_acc_best, inner_val_map_best
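
train_dis above delegates learning-rate decay and early stopping to adj_optim, which is not shown here. The sketch below reconstructs typical behaviour from the call site alone (shrink the SGD learning rate when validation stops improving, stop once it falls below minlr, and stop adaptive optimizers after repeated non-improving epochs); the real helper may differ, and adj_optim_sketch is an illustrative name.

def adj_optim_sketch(optims, optimizer, minlr, lrshrink, stop_training,
                     times_no_improvement):
    """Illustrative reconstruction of adj_optim; not the project's implementation."""
    if 'sgd' in optims:
        # Shrink the learning rate and stop once it becomes negligible.
        optimizer.param_groups[0]['lr'] /= lrshrink
        print('Shrinking lr by {0}, new lr = {1}'.format(
            lrshrink, optimizer.param_groups[0]['lr']))
        if optimizer.param_groups[0]['lr'] < minlr:
            stop_training = True
    else:
        # Adaptive optimizers (e.g. Adam): stop after a few epochs without improvement.
        stop_training = times_no_improvement > 2
    return stop_training
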
Code example #5
def main():
    # Parse cmd line args
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        type=str,
                        required=True,
                        help='the name of the desired model configuration')
    parser.add_argument('--experiment',
                        type=str,
                        required=True,
                        help='the name of the desired experiment config')
    parser.add_argument('--batch_size',
                        type=int,
                        required=False,
                        default=8,
                        help='the train/test batch size')
    parser.add_argument('--test_split',
                        type=float,
                        required=False,
                        default=0.2,
                        help='the pct of data used for test')
    parser.add_argument('--epochs',
                        type=int,
                        required=False,
                        default=5,
                        help='the number of training epochs')
    args = parser.parse_args()
    print(args.model)
    print(args.experiment)

    # Get model, optimizer and data
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE:", device)
    model = util.models[args.model].to(device)
    optimizer = util.get_optimizer(model, args)

    train_loader, test_loader = \
        util.get_data_loaders(game_dataset.GameDataset(None, "./data/"),
                              args.batch_size,
                              args.test_split)

    # Train and evaluate
    for epoch in range(args.epochs):
        print("EPOCH", epoch)
        train(device, model, train_loader, args.test_split, optimizer, epoch,
              args)
        test(device, model, test_loader, args.test_split)
Code example #6
    def evaluate_z_test(self, args, z_test, target, model, **kwargs):
        og_z_test = deepcopy(z_test)
        print('does requires_grad come with?', z_test)

        loss = torch.ones(1, requires_grad=True).to(args.device)
        if self.logger.global_step % args.steps_per_z_test == 0: 
            # Will need to backprop into the input
            z_test = z_test.requires_grad_()

            # Optimizer with trainable input parameters
            optimizer = util.get_optimizer([z_test], args)
            
            # Freeze all layers
            for param in model.parameters():
                param.requires_grad = False

            # Use model in eval mode for dropout/BN layers to behave their best
            model.eval()

            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    logits = model.forward(z_test).float()
                    probs = F.sigmoid(logits) 
                else:
                    probs = model.forward(z_test).float()

                loss = self.z_test_loss_fn(probs, target).mean()

                loss.backward()
                optimizer.step() 
                optimizer.zero_grad()
                
            # Unfreeze all layers
            for param in model.parameters():
                param.requires_grad = True

            model.train()
        
        print('sanity ztest', torch.sum(og_z_test), torch.sum(z_test))
        
        return z_test, loss #TODO! debug that z-test actually changes...
Code example #7
    def __init__(self, cfg_module, writer):
        super().__init__(save_name=f'{cfg_module["network"]["name"]}_cls',
                         writer=writer)

        dataset_type = cfg_module["data"].get("dataset_type", "srproj")

        if dataset_type == "directory":
            # Setup Dataloader (put custom dataset such as "cls_dataset")
            self.trainloader = get_dataloader("training", cfg_module["data"],
                                              cls_dataset_directory)
            self.valloader = get_dataloader("validation", cfg_module["data"],
                                            cls_dataset_directory)
        assert (self.trainloader.dataset.n_classes == self.valloader.dataset.
                n_classes), "train/val dataset n_classes missmatch"
        cfg_module["network"].update(
            n_classes=self.trainloader.dataset.n_classes)
        cfg_module["metric"].update(
            n_classes=self.trainloader.dataset.n_classes)

        # Define Module type (for external usage)
        self.module_type = "classification"
        # Setup Model
        # No need to call model.to(device) here
        self.network = cls_network(cfg_module["network"])

        # Setup Loss
        self.loss = cls_loss(cfg_module["loss"])

        # Setup Metric
        self.metric = cls_metric(cfg_module["metric"])

        # Setup Optimizer and Scheduler
        self.optimizer = get_optimizer(cfg_module["optimizer"],
                                       self.network.parameters())
        self.scheduler_name = (cfg_module["scheduler"]["name"]
                               if cfg_module["scheduler"] else "")
        self.scheduler = get_scheduler(cfg_module["scheduler"], self.optimizer)

        # Load State
        self._load_state(cfg_module["load_state"])
Code example #8
File: test.py  Project: sharonzhou/generator
def train_inverted_net(args):
    # Start by training an external model on samples of G(z) -> z inversion
    model = util.get_invert_model(args)

    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    print(f'{args.invert_model} num params {count_parameters(model)}')

    generator = util.get_model(args)
    if generator is not None:
        generator = nn.DataParallel(generator, args.gpu_ids)
        generator = generator.to(args.device)
        print(f'{args.model} num params {count_parameters(generator)}')
    else:
        # Load saved pairings (ProGAN/StyleGAN)
        pairing_dir = '/deep/group/gen-eval/model-training/src/GAN_models/stylegan'
        pairing_path = f'{pairing_dir}/otavio_sampled_output/pairing.csv'
        pairings = pd.read_csv(pairing_path)

        num_pairings = len(pairings)
        noise_targets = pairings['noise']
        image_inputs = pairings['image']

    if 'BigGAN' in args.model:
        class_vector = one_hot_from_int(207, batch_size=args.batch_size)
        class_vector = torch.from_numpy(class_vector)
        class_vector = class_vector.cuda()

    # TODO: remove; I don't think the GPU can be used inside the loader
    #loader = get_loader(args, phase='invert')

    #logger = TestLogger(args)
    #logger.log_hparams(args)

    criterion = torch.nn.MSELoss().to(args.device)
    optimizer = util.get_optimizer(model.parameters(), args)

    for i in range(args.num_invert_epochs):
        if generator is not None:
            noise_target = util.get_noise(args)

            image_input = generator.forward(noise_target).float()
            image_input = (image_input + 1.) / 2.
        else:
            # TODO: make into loader
            idx = i % num_pairings
            noise_target = np.load(f'{pairing_dir}/{noise_targets[idx]}')
            noise_target = torch.from_numpy(noise_target).float()
            print(f'noise target shape {noise_target.shape}')

            image_input = np.array(
                Image.open(f'{pairing_dir}/{image_inputs[idx]}'))
            image_input = torch.from_numpy(image_input / 255.)
            image_input = image_input.float().unsqueeze(0)
            image_input = image_input.permute(0, 3, 1, 2)

        noise_target = noise_target.cuda()
        image_input = image_input.cuda()

        with torch.set_grad_enabled(True):
            probs = model.forward(image_input)

            loss = torch.zeros(1, requires_grad=True).to(args.device)
            loss = criterion(probs, noise_target)
            print(f'iter {i}: loss = {loss}')

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        if i % 1 == 0:
            corres_image_input = image_input.detach().cpu()
            corres_np = util.convert_image_from_tensor(corres_image_input)

            # Run check - saving image
            if 'BigGAN' in args.model:
                predicted_image = generator.forward(probs, class_vector,
                                                    truncation).float()
            else:
                if generator is not None:
                    predicted_image = generator.forward(probs).float()

                    predicted_image = predicted_image.detach().cpu()
                    predicted_image = (predicted_image + 1) / 2.
                    predicted_np = util.convert_image_from_tensor(
                        predicted_image)

                    if len(predicted_np.shape) == 4:
                        predicted_np = predicted_np[0]
                        corres_np = corres_np[0]
                    visuals = util.concat_images([predicted_np, corres_np])
                    visuals_pil = Image.fromarray(visuals)
                    timestamp = datetime.now().strftime('%b%d_%H%M%S%f')
                    visuals_image_dir = f'predicted_inversion_images/{args.model}'
                    os.makedirs(visuals_image_dir, exist_ok=True)
                    visuals_image_path = f'{visuals_image_dir}/{timestamp}_{i}.png'
                    visuals_pil.save(visuals_image_path)

                    print(f'Saved {visuals_image_path}')
                else:
                    # Save noise vector - do forward separately in tf env
                    probs = probs.detach().cpu().numpy()
                    pred_noise_dir = f'predicted_inversion_noise/{args.model}'
                    os.makedirs(pred_noise_dir, exist_ok=True)

                    pred_noise_path = f'{pred_noise_dir}/{args.model}_noise_{i}.npy'
                    np.save(pred_noise_path, probs)

                    print(f'Saved {pred_noise_path}')

        if i % 1 == 0:
            corres_image_input = image_input.detach().cpu()
            corres_np = util.convert_image_from_tensor(corres_image_input)

            if len(corres_np.shape) == 4:
                corres_np = corres_np[0]

            corres_pil = Image.fromarray(corres_np)
            timestamp = datetime.now().strftime('%b%d_%H%M%S%f')
            corres_image_dir = f'generated_images/{args.model}'
            os.makedirs(corres_image_dir, exist_ok=True)
            corres_image_path = f'{corres_image_dir}/{timestamp}_{i}.png'
            corres_pil.save(corres_image_path)

    # saver = ModelSaver(args)
    global_step = args.num_invert_epochs
    ckpt_dict = {
        'ckpt_info': {
            'global_step': global_step
        },
        'model_name': model.module.__class__.__name__,
        'model_args': model.module.args_dict(),
        'model_state': model.to('cpu').state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    ckpt_dir = os.path.join(args.save_dir, f'{args.model}')
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(
        ckpt_dir, f'{args.invert_model}_step_{global_step}.pth.tar')
    torch.save(ckpt_dict, ckpt_path)
    print(f'Saved model to {ckpt_path}')

    import pdb
    pdb.set_trace()

    return model
Code example #9
def train(args):
    # Get loader for outer loop training
    loader = get_loader(args)
    target_image_shape = loader.dataset.target_image_shape
    setattr(args, 'target_image_shape', target_image_shape)

    # Load model
    model_fn = models.__dict__[args.model]
    model = model_fn(**vars(args))
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Print model parameters
    print('Model parameters: name, size, mean, std')
    for name, param in model.named_parameters():
        print(name, param.size(), torch.mean(param), torch.std(param))

    # Get optimizer and loss
    parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    loss_fn = util.get_loss_fn(args.loss_fn, args)

    z_loss_fn = util.get_loss_fn(args.loss_fn, args)

    # Get logger, saver
    logger = TrainLogger(args)
    saver = ModelSaver(args)

    print(f'Logs: {logger.log_dir}')
    print(f'Ckpts: {args.save_dir}')

    # Train model
    logger.log_hparams(args)
    batch_size = args.batch_size
    while not logger.is_finished_training():
        logger.start_epoch()

        for input_noise, target_image, mask, z_test_target, z_test in loader:
            logger.start_iter()

            if torch.cuda.is_available():
                input_noise = input_noise.to(args.device)  #.cuda()
                target_image = target_image.cuda()
                mask = mask.cuda()
                z_test = z_test.cuda()
                z_test_target = z_test_target.cuda()

            masked_target_image = target_image * mask
            obscured_target_image = target_image * (1.0 - mask)

            # Input is noise tensor, target is image
            model.train()
            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    logits = model.forward(input_noise).float()
                    probs = F.sigmoid(logits)

                    # Debug logits and diffs
                    logger.debug_visualize(
                        [logits, logits * mask, logits * (1.0 - mask)],
                        unique_suffix='logits-train')
                else:
                    probs = model.forward(input_noise).float()

                # With backprop, calculate (1) masked loss, loss when mask is applied.
                # Loss is done elementwise without reduction, so must take mean after.
                # Easier for debugging.
                masked_probs = probs * mask
                masked_loss = torch.zeros(1,
                                          requires_grad=True).to(args.device)
                masked_loss = loss_fn(masked_probs, masked_target_image).mean()

                masked_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Without backprop, calculate (2) full loss on the entire image,
            # And (3) the obscured loss, region obscured by mask.
            model.eval()
            with torch.no_grad():
                if args.use_intermediate_logits:
                    logits_eval = model.forward(input_noise).float()
                    probs_eval = F.sigmoid(logits_eval)

                    # Debug logits and diffs
                    logger.debug_visualize(
                        [logits_eval, logits_eval * mask,
                         logits_eval * (1.0 - mask)],
                        unique_suffix='logits-eval')
                else:
                    probs_eval = model.forward(input_noise).float()

                masked_probs_eval = probs_eval * mask
                masked_loss_eval = torch.zeros(1)
                masked_loss_eval = loss_fn(masked_probs_eval,
                                           masked_target_image).mean()

                full_loss_eval = torch.zeros(1)
                full_loss_eval = loss_fn(probs_eval, target_image).mean()

                obscured_probs_eval = probs_eval * (1.0 - mask)
                obscured_loss_eval = torch.zeros(1)
                obscured_loss_eval = loss_fn(obscured_probs_eval,
                                             obscured_target_image).mean()

            # With backprop on only the input z, (4) run one step of z-test and get z-loss
            z_optimizer = util.get_optimizer([z_test.requires_grad_()], args)
            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    z_logits = model.forward(z_test).float()
                    z_probs = F.sigmoid(z_logits)
                else:
                    z_probs = model.forward(z_test).float()

                z_loss = torch.zeros(1, requires_grad=True).to(args.device)
                z_loss = z_loss_fn(z_probs, z_test_target).mean()

                z_loss.backward()
                z_optimizer.step()
                z_optimizer.zero_grad()

            if z_loss < args.max_z_test_loss:  # TODO: include this part into the metrics/saver stuff below
                # Save MSE on obscured region
                final_metrics = {'final/score': obscured_loss_eval.item()}
                logger._log_scalars(final_metrics)
                print('z loss', z_loss)
                print('Final MSE value', obscured_loss_eval)

            # TODO: Make a function for metrics - or at least make sure dict includes all possible best ckpt metrics
            metrics = {'masked_loss': masked_loss.item()}
            saver.save(logger.global_step,
                       model,
                       optimizer,
                       args.device,
                       metric_val=metrics.get(args.best_ckpt_metric, None))
            # Log both train and eval model settings, and visualize their outputs
            logger.log_status(
                inputs=input_noise,
                targets=target_image,
                probs=probs,
                masked_probs=masked_probs,
                masked_loss=masked_loss,
                probs_eval=probs_eval,
                masked_probs_eval=masked_probs_eval,
                obscured_probs_eval=obscured_probs_eval,
                masked_loss_eval=masked_loss_eval,
                obscured_loss_eval=obscured_loss_eval,
                full_loss_eval=full_loss_eval,
                z_target=z_test_target,
                z_probs=z_probs,
                z_loss=z_loss,
                save_preds=args.save_preds,
            )

            logger.end_iter()

        logger.end_epoch()

    # Last log after everything completes
    logger.log_status(
        inputs=input_noise,
        targets=target_image,
        probs=probs,
        masked_probs=masked_probs,
        masked_loss=masked_loss,
        probs_eval=probs_eval,
        masked_probs_eval=masked_probs_eval,
        obscured_probs_eval=obscured_probs_eval,
        masked_loss_eval=masked_loss_eval,
        obscured_loss_eval=obscured_loss_eval,
        full_loss_eval=full_loss_eval,
        z_target=z_test_target,
        z_probs=z_probs,
        z_loss=z_loss,
        save_preds=args.save_preds,
        force_visualize=True,
    )
Code example #10
File: train.py  Project: yxliang/lca-code
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.
    """

    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]

    # Get model
    if model_args.ckpt_path:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path,
                                                 args.gpu_ids, model_args,
                                                 data_args)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)
    if model_args.ckpt_path:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids,
                                  optimizer, lr_scheduler)

    # Get loaders and class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path
    #TODO: Remove this when we decide which transformation to use in the end
    #transforms_imgaug = ImgAugTransform()
    train_loader = get_loader(data_args,
                              transform_args,
                              train_csv_name,
                              task_sequence,
                              data_args.su_train_frac,
                              data_args.nih_train_frac,
                              data_args.pocus_train_frac,
                              data_args.tcga_train_frac,
                              0,
                              0,
                              args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True,
                              shuffle=True,
                              transform=model_args.transform,
                              normalize=model_args.normalize)
    eval_loaders = get_eval_loaders(data_args,
                                    transform_args,
                                    task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    normalize=model_args.normalize)
    class_weights = train_loader.dataset.class_weights
    print(" class weights:")
    print(class_weights)

    # Get loss functions
    uw_loss_fn = get_loss_fn('cross_entropy',
                             args.device,
                             model_args.model_uncertainty,
                             args.has_tasks_missing,
                             class_weights=class_weights)

    w_loss_fn = get_loss_fn('weighted_loss',
                            args.device,
                            model_args.model_uncertainty,
                            args.has_tasks_missing,
                            mask_uncertain=False,
                            class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch,
                         args.num_epochs, args.batch_size,
                         len(train_loader.dataset), args.device)

    eval_args = {}
    eval_args['num_visuals'] = logger_args.num_visuals
    eval_args['iters_per_eval'] = logger_args.iters_per_eval
    eval_args['has_missing_tasks'] = args.has_tasks_missing
    eval_args['model_uncertainty'] = model_args.model_uncertainty
    eval_args['class_weights'] = class_weights
    eval_args['max_eval'] = logger_args.max_eval
    eval_args['device'] = args.device
    eval_args['optimizer'] = args.optimizer
    evaluator = get_evaluator('classification', eval_loaders, logger,
                              eval_args)

    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0
    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, info_dict in train_loader:

            logger.start_iter()

            # Evaluate and save periodically
            metrics, curves = evaluator.evaluate(model, args.device,
                                                 logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)

            assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None
            saver.save(logger.global_step,
                       logger.epoch,
                       model,
                       optimizer,
                       lr_scheduler,
                       args.device,
                       metric_val=metric_val)
            lr_step = util.step_scheduler(
                lr_scheduler,
                metrics,
                lr_step,
                best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]

            with torch.set_grad_enabled(True):

                logits = model.forward(inputs.to(args.device))

                unweighted_loss = uw_loss_fn(logits, targets.to(args.device))

                weighted_loss = w_loss_fn(logits, targets.to(
                    args.device)) if w_loss_fn else None

                logger.log_iter(inputs, logits, targets, unweighted_loss,
                                weighted_loss, optimizer)

                optimizer.zero_grad()
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)
Code example #11
def main():
    global best_acc
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    # load data
    transformations = get_transforms(input_size=args.image_size,
                                     test_size=args.image_size)
    # train data
    train_set = CGPIM_Data(root=args.train_txt_path,
                           transform=transformations['val_train'],
                           isTrain=True)
    train_loader = data.DataLoader(train_set,
                                   batch_size=args.batch_size,
                                   shuffle=True)
    # val data
    val_set = CGPIM_Data(root=args.val_txt_path,
                         transform=transformations['val_test'],
                         isTrain=False)
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False)

    # define model
    model = ResNeXt(2, 3, [3, 4, 6, 3], 2)
    model.cuda()

    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = get_optimizer(model, args)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.2,
                                                           patience=5,
                                                           verbose=False)

    # load checkpoint
    start_epoch = args.start_epoch
    for epoch in range(start_epoch, args.epochs):

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch)
        test_loss, val_acc = val(val_loader, model, criterion, epoch)
        scheduler.step(test_loss)
        print('train_loss: %.3f, val_loss:%.3f, train_acc:%.3f, val_acc:%.3f' %
              (train_loss, test_loss, train_acc, val_acc))

        # save_model
        is_best = val_acc >= best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'fold': 0,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'train_acc': train_acc,
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            single=True,
            checkpoint=args.checkpoint)

    print("best acc = ", best_acc)
Code example #12
File: trainer.py  Project: satishwarkedas/DisExtract
    dis_net = DisSent(config_dis_model)
    logger.info(dis_net)
else:
    # if starting epoch is not 1, we resume training
    # 1. load in model
    # 2. resume with the previous learning rate
    model_path = pjoin(params.outputdir, params.outputmodelname + ".pickle")  # this is the best model
    # this might have conflicts with gpu_idx...
    dis_net = torch.load(model_path)

# loss
loss_fn = nn.CrossEntropyLoss()
loss_fn.size_average = False

# optimizer
optim_fn, optim_params = get_optimizer(params.optimizer)
optimizer = optim_fn(dis_net.parameters(), **optim_params)

if params.cur_epochs != 1:
    optimizer.param_groups[0]['lr'] = params.cur_lr

# cuda by default
dis_net.cuda()
loss_fn.cuda()

"""
TRAIN
"""
val_acc_best = -1e10 if params.cur_epochs == 1 else params.cur_valid
adam_stop = False
stop_training = False
Code example #13
    def __init__(self, env, build_agent, task_index, writer, args):
        print '* A3C arguments:'
        vargs = vars(args)
        for k in sorted(vargs.keys()):
            print k, vargs[k]

        self.env = env
        self.task_index = task_index
        self.is_chief = task_index == 0

        # build compute graphs
        worker_device = '/job:worker/task:{}'.format(task_index)
        # on parameter server and locally
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=1,
                                               worker_device=worker_device)):
            with tf.variable_scope('global'):
                # clone of the model for parameter server
                build_agent(env.spec)
                global_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    tf.get_variable_scope().name)
                self.global_tick = tf.get_variable(
                    'global_tick', [],
                    'int32',
                    trainable=False,
                    initializer=tf.zeros_initializer())
                # create the shared optimizer (used by all workers)
                if args.shared:
                    optimizer = get_optimizer(args.optimizer,
                                              args.learning_rate,
                                              args.momentum)

        # local only
        with tf.device(worker_device):
            with tf.variable_scope('local'):
                self.agent = build_agent(env.spec)
                assert isinstance(self.agent,
                                  (ActorCriticAgent, StatefulActorCriticAgent))
                self.use_history = isinstance(self.agent,
                                              StatefulActorCriticAgent)
                local_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    tf.get_variable_scope().name)

            self.local_step = 0
            # copy parameters from `global/` to `local/`
            self.sync_op = tf.group(*[
                v1.assign(v2)
                for v1, v2 in zip(local_variables, global_variables)
            ])

            # define objectives
            # input variables
            self.actions_taken_ph = tf.placeholder('int32')
            self.target_value_ph = tf.placeholder('float')
            self.advantage_ph = tf.placeholder('float')

            # entropy regularizer to encourage action diversity (naive exploration)
            log_action_probs = tf.nn.log_softmax(self.agent.action_logits)
            action_probs = tf.nn.softmax(self.agent.action_logits)

            self.action_entropy = -tf.reduce_sum(
                action_probs * log_action_probs)
            taken_action_logits = vector_slice(log_action_probs,
                                               self.actions_taken_ph)
            # taken_action_logits = mask_slice(log_action_probs, actions_taken_ph)

            # objective for value estimation
            self.value_objective = tf.reduce_sum(
                tf.square(self.target_value_ph - self.agent.state_values))

            # objective for computing policy gradient
            self.policy_objective = tf.reduce_sum(taken_action_logits *
                                                  self.advantage_ph)

            # total objective
            # maximize policy objective
            # minimize value objective
            # maximize action entropy
            self.objective = -self.policy_objective + args.value_objective_coeff * self.value_objective - args.action_entropy_coeff * self.action_entropy

            grads = tf.gradients(self.objective, local_variables)
            # apply gradients to the global parameters
            batch_len = tf.shape(self.actions_taken_ph)[0]
            per_batch_len = 1. / tf.to_float(batch_len)
            inc_tick = self.global_tick.assign_add(batch_len)

            self.reward_gamma = args.reward_gamma
            self.return_lambda = args.return_lambda
            self.return_n_step = args.return_n_step
            self.advantage_lambda = args.advantage_lambda
            self.advantage_n_step = args.advantage_n_step

            self.writer = writer

            # summaries
            self.episode_len_ph = tf.placeholder('float', name='episode_len')
            self.episode_reward_ph = tf.placeholder('float',
                                                    name='episode_reward')
            self.ticks_per_second_ph = tf.placeholder('float',
                                                      name='ticks_per_second')
            self.steps_per_second_ph = tf.placeholder('float',
                                                      name='steps_per_second')

            self.per_episode_summary = tf.summary.merge([
                tf.summary.scalar('episodic/reward', self.episode_reward_ph),
                tf.summary.scalar('episodic/length', self.episode_len_ph),
                tf.summary.scalar('episodic/reward_per_tick',
                                  self.episode_reward_ph /
                                  self.episode_len_ph),
            ])

            norm = tf.global_norm(grads)
            var_norm = tf.global_norm(local_variables)

            # local optimizer
            if not args.shared:
                optimizer = get_optimizer(args.optimizer, args.learning_rate,
                                          args.momentum)

            if args.no_grad_clip:
                normed_grads = grads
                clipped_norm = norm
            else:
                # gradient clipping
                normed_grads, _ = tf.clip_by_global_norm(
                    grads, args.clip_norm, norm)
                clipped_norm = tf.minimum(args.clip_norm, norm)

            self.update_op = tf.group(
                optimizer.apply_gradients(zip(normed_grads, global_variables)),
                inc_tick)

            self.summary_interval = args.summary_interval
            if self.is_chief:
                print '* gradients'
                grad_summaries = []
                for g, v in zip(normed_grads, global_variables):
                    grad_summaries.append(
                        tf.summary.histogram('gradients/%s' % v.name, g))
                    print '%s -> %s' % (g.name, v.name)

                self.per_step_summary = tf.summary.merge(grad_summaries + [
                    tf.summary.scalar('model/objective', self.objective *
                                      per_batch_len),
                    tf.summary.scalar('model/state_value_objective',
                                      self.value_objective * per_batch_len),
                    tf.summary.scalar('model/policy_objective',
                                      self.policy_objective * per_batch_len),
                    tf.summary.scalar(
                        'model/action_perplexity',
                        tf.exp(self.action_entropy * per_batch_len)),
                    tf.summary.scalar('model/gradient_norm', norm),
                    tf.summary.scalar('model/clipped_gradient_norm',
                                      clipped_norm),
                    tf.summary.scalar('model/var_norm', var_norm),
                    tf.summary.scalar('chief/steps_per_second',
                                      self.steps_per_second_ph),
                    tf.summary.scalar('chief/ticks_per_second',
                                      self.ticks_per_second_ph),
                ])

            self.n_update_ticks = None if args.n_update_ticks == 0 else args.n_update_ticks

            self.step_start_at = None

            # process returns
            if args.return_eval == 'td':
                self.process_returns = lambda rewards, values, bootstrap_value: td_return(
                    rewards, values, self.reward_gamma, bootstrap_value)
            elif args.return_eval == 'mc':
                self.process_returns = lambda rewards, values, bootstrap_value: mc_return(
                    rewards, self.reward_gamma, bootstrap_value)
            elif args.return_eval == 'n-step':
                self.process_returns = lambda rewards, values, bootstrap_value: n_step_return(
                    rewards, values, self.reward_gamma, bootstrap_value, self.
                    return_n_step)
            else:
                self.process_returns = lambda rewards, values, bootstrap_value: lambda_return(
                    rewards, values, self.reward_gamma, self.return_lambda,
                    bootstrap_value)

            # process advantages
            if args.advantage_eval == 'td':
                self.process_advantages = lambda rewards, values, bootstrap_value: td_return(
                    rewards, values, self.reward_gamma, bootstrap_value
                ) - values
            elif args.advantage_eval == 'mc':
                self.process_advantages = lambda rewards, values, bootstrap_value: mc_return(
                    rewards, self.reward_gamma, bootstrap_value) - values
            elif args.advantage_eval == 'n-step':
                self.process_advantages = lambda rewards, values, bootstrap_value: n_step_return(
                    rewards, values, self.reward_gamma, bootstrap_value, self.
                    advantage_n_step) - values
            else:
                self.process_advantages = lambda rewards, values, bootstrap_value: lambda_advantage(
                    rewards, values, self.reward_gamma, self.advantage_lambda,
                    bootstrap_value)
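
The return/advantage processing above relies on helpers (td_return, mc_return, n_step_return, lambda_return, lambda_advantage) that are not included in this example. As one illustration, here is a minimal NumPy sketch of a lambda-return computed from the same arguments the call sites pass; it is an assumption about those helpers, not the project's code.

import numpy as np

def lambda_return_sketch(rewards, values, gamma, lam, bootstrap_value):
    """TD(lambda) returns computed backwards from a bootstrap value (illustrative).

    values[t] is the value estimate V(s_t); bootstrap_value estimates the
    state following the last reward (use 0.0 if the episode terminated).
    """
    returns = np.zeros(len(rewards), dtype=np.float32)
    next_value = bootstrap_value
    next_return = bootstrap_value
    for t in reversed(range(len(rewards))):
        returns[t] = rewards[t] + gamma * ((1.0 - lam) * next_value + lam * next_return)
        next_value = values[t]
        next_return = returns[t]
    return returns
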
Code example #14
File: train.py  Project: stanfordmlgroup/CheXaid
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.
    """
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]
    print('gpus: ', args.gpu_ids)
    # Get model
    if model_args.ckpt_path:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path, args.gpu_ids, model_args, data_args)
        if not logger_args.restart_epoch_count:
            args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        num_covars = len(model_args.covar_list.split(';'))
        model.transform_model_shape(len(task_sequence), num_covars)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)

    # The optimizer is loaded from the ckpt if one exists and the new model
    # architecture is the same as the old one (classifier is not transformed).
    if model_args.ckpt_path and not model_args.transform_classifier:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids, optimizer, lr_scheduler)

    # Get loaders and class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path

    # Put all CXR training fractions into one dictionary and pass it to the loader
    cxr_frac = {'pocus': data_args.pocus_train_frac, 'hocus': data_args.hocus_train_frac,
                'pulm': data_args.pulm_train_frac}
    train_loader = get_loader(data_args,
                              transform_args,
                              train_csv_name,
                              task_sequence,
                              data_args.su_train_frac,
                              data_args.nih_train_frac,
                              cxr_frac,
                              data_args.tcga_train_frac,
                              args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True,
                              shuffle=True,
                              covar_list=model_args.covar_list,
                              fold_num=data_args.fold_num)
    eval_loaders = get_eval_loaders(data_args,
                                    transform_args,
                                    task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    covar_list=model_args.covar_list,
                                    fold_num=data_args.fold_num)
    class_weights = train_loader.dataset.class_weights

    # Get loss functions
    uw_loss_fn = get_loss_fn(args.loss_fn, args.device, model_args.model_uncertainty,
        args.has_tasks_missing, class_weights=class_weights)
    w_loss_fn = get_loss_fn('weighted_loss', args.device, model_args.model_uncertainty,
        args.has_tasks_missing, class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch, args.num_epochs, args.batch_size,
        len(train_loader.dataset), args.device, normalization=transform_args.normalization)
    
    eval_args = {}
    eval_args['num_visuals'] = logger_args.num_visuals
    eval_args['iters_per_eval'] = logger_args.iters_per_eval
    eval_args['has_missing_tasks'] = args.has_tasks_missing
    eval_args['model_uncertainty'] = model_args.model_uncertainty
    eval_args['class_weights'] = class_weights
    eval_args['max_eval'] = logger_args.max_eval
    eval_args['device'] = args.device
    eval_args['optimizer'] = optimizer
    evaluator = get_evaluator('classification', eval_loaders, logger, eval_args)

    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0
    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, info_dict, covars in train_loader:
            logger.start_iter()

            # Evaluate and save periodically
            metrics, curves = evaluator.evaluate(model, args.device, logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)
            assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None
            saver.save(logger.global_step, logger.epoch, model, optimizer, lr_scheduler, args.device,
                       metric_val=metric_val, covar_list=model_args.covar_list)
            lr_step = util.step_scheduler(lr_scheduler, metrics, lr_step, best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]

            with torch.set_grad_enabled(True):
            # with torch.autograd.set_detect_anomaly(True):

                logits = model.forward([inputs.to(args.device), covars])

                # Scale up TB so that its loss counts for more when upweight_tb is True.
                if model_args.upweight_tb is True:
                    tb_targets = targets.narrow(1, 0, 1)
                    findings_targets = targets.narrow(1, 1, targets.shape[1] - 1)
                    tb_targets = tb_targets.repeat(1, targets.shape[1] - 1)
                    new_targets = torch.cat((tb_targets, findings_targets), 1)

                    tb_logits = logits.narrow(1, 0, 1)
                    findings_logits = logits.narrow(1, 1, logits.shape[1] - 1)
                    tb_logits = tb_logits.repeat(1, logits.shape[1] - 1)
                    new_logits = torch.cat((tb_logits, findings_logits), 1)
                else:
                    new_logits = logits
                    new_targets = targets

                    
                unweighted_loss = uw_loss_fn(new_logits, new_targets.to(args.device))

                weighted_loss = w_loss_fn(logits, targets.to(args.device)) if w_loss_fn else None

                logger.log_iter(inputs, logits, targets, unweighted_loss, weighted_loss, optimizer)

                optimizer.zero_grad()
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)
Code example #15
    def __init__(self, vocabulary_size, embedding_size, validation_set,
                 **kwargs):
        '''
        Parameters
        ----------
        vocabulary_size :   int
                            Number of classes to predict from (the size of the
                            considered vocabulary)
        embedding_size  :   int
                            Dimensionality of the space to project the words
                            onto. Length of the vectors representing a word
        validation_set  :   Set of word ids to check the similarity for

        learning_rate   :   float
                            Optimizer learning rate
        optimizer       :   str
                            Name of the tf optimizer to use (e.g.  "GradientDescent")
        noise_samples   :   int
                            Number of noise samples for NCE sampling
        '''
        ###########################
        #  Extract relevant args  #
        ###########################
        optimizer_cls = get_optimizer(
            kwargs.get('optimizer', 'GradientDescent'))
        learning_rate = kwargs.get('learning_rate', 1)
        noise_samples = kwargs.get('noise_samples', 64)

        ############################################
        #  Input word id + target context word id  #
        ############################################
        self.target_word_id = tf.placeholder(tf.int32,
                                             shape=(None, ),
                                             name='target')
        self.target_context_id = tf.placeholder(tf.int32,
                                                shape=(None, 1),
                                                name='context')

        ##################
        #  Hidden layer  #
        ##################
        W_context, _ = get_weights_and_bias((vocabulary_size, embedding_size))
        a_h = tf.nn.embedding_lookup(W_context, self.target_word_id)

        ##################
        #  Output layer  #
        ##################
        # notice that - strangely - the weights matrix must be transposed from
        # what you would use for multiplying a_h * W. This seems to be a quirk
        # of tf nce_loss
        initializer = tf.random_normal_initializer(stddev=1 /
                                                   np.sqrt(embedding_size))
        W_target, b_target = get_weights_and_bias(
            (vocabulary_size, embedding_size),
            shape_b=(vocabulary_size, ),
            initializer_w=initializer)

        ##########
        #  Loss  #
        ##########
        with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(
                tf.nn.nce_loss(weights=W_target,
                               biases=b_target,
                               labels=self.target_context_id,
                               inputs=a_h,
                               num_sampled=noise_samples,
                               num_classes=vocabulary_size))

        #########################
        #  TensorBoard logging  #
        #########################
        with tf.variable_scope('summary'):
            tf.summary.scalar('loss', self.loss)
            self.merged = tf.summary.merge_all()

        self.train_step = optimizer_cls(learning_rate).minimize(self.loss)

        ########################################
        #  Accuracy for a fixed set of words.  #
        ########################################
        # Compute the cosine similarity between minibatch examples and all
        # embeddings. Copypasta from stackoverflow
        norm = tf.sqrt(tf.reduce_sum(tf.square(W_context), 1, keep_dims=True))
        normalized_embeddings = W_context / norm
        valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings,
                                                  validation_set)
        self.similarity = tf.matmul(valid_embeddings,
                                    normalized_embeddings,
                                    transpose_b=True)
        # get the 8 closest, because the first closest is always the word
        # itself.
        self.closest_words = tf.nn.top_k(self.similarity, 8).indices[:, 1:]
def pgtrain(optims_gen,
            optims_dis,
            generator,
            agent,
            discriminator,
            bsize,
            embed_dim,
            trainSample,
            validSample,
            testSample,
            val_acc_best,
            val_preck_best,
            val_loss_best,
            action_num,
            max_length,
            recom_length,
            gen_ratio=0.1,
            n_epochs=5,
            write_item='click_gen.txt',
            write_target='tar_gen.txt',
            write_reward='reward_gen.txt',
            write_action='action_gen.txt',
            plot_fig=True,
            pretrain=False):
    outputdir = "model_output"
    outputmodelname = "simu.model.pth"
    lrshrink = 5
    minlr = 1e-5

    #Evaluation loss functions
    loss_fn_target = nn.CrossEntropyLoss()
    loss_fn_reward = nn.BCEWithLogitsLoss()
    loss_fn_target.size_average = True
    loss_fn_target.to(device)
    loss_fn_reward.size_average = True
    loss_fn_reward.to(device)

    inner_val_preck_best = val_preck_best
    inner_val_acc_best = val_acc_best
    inner_loss_best = val_loss_best
    epoch = 1
    eval_type = 'valid'
    g_step = 1
    d_step = 1
    evalacc_all = [val_acc_best]
    evalpreck_all = [val_preck_best]
    #Define the optimizer
    optim_fn_gen, optim_params_gen = get_optimizer(optims_gen)
    optim_fn_dis, optim_params_dis = get_optimizer(optims_dis)
    optimizer_dis = optim_fn_dis(
        filter(lambda p: p.requires_grad, discriminator.parameters()),
        **optim_params_dis)
    params_agent = list(agent.parameters())
    params_usr = list(generator.parameters())
    optimizer_agent = optim_fn_gen(
        filter(lambda p: p.requires_grad, params_agent), **optim_params_gen)
    optimizer_usr = optim_fn_gen(filter(lambda p: p.requires_grad, params_usr),
                                 **optim_params_gen)
    while epoch <= n_epochs:
        print('\nAdversarial Policy Gradient Training!')
        # Select subset of trainSample
        subnum = 8000
        for i in range(g_step):
            print('G-step')
            if pretrain:
                print('For Pretraining')
                _ = train_gen_pg_each(generator, agent,
                                      discriminator, epoch, trainSample,
                                      trainSample.length(), optimizer_agent,
                                      optimizer_usr, bsize, embed_dim,
                                      recom_length, max_length, action_num,
                                      device, 0, pretrain)
            else:
                print('For Policy Gradient Update')
                #shuffle_index=np.random.permutation(origin.length())
                _ = train_gen_pg_each(generator, agent, discriminator, epoch,
                                      trainSample, subnum, optimizer_agent,
                                      optimizer_usr, bsize, embed_dim,
                                      recom_length, max_length, action_num,
                                      device, 0.1, pretrain)

        # save model
        # Evaluate without eos, no eos input
        print("Agent evaluation!")
        eval_acc, eval_preck = evaluate_agent(agent,
                                              epoch,
                                              bsize,
                                              recom_length,
                                              validSample,
                                              testSample,
                                              device,
                                              eval_type='valid')
        print("User model evaluation!")
        _ = evaluate_user(generator, epoch, bsize, recom_length, validSample,
                          testSample, loss_fn_target, loss_fn_reward, device,
                          eval_type)
        print("Interaction evaluation!")
        _ = evaluate_interaction(
            (generator, agent), epoch, bsize, recom_length, validSample,
            testSample, loss_fn_target, loss_fn_reward, device, eval_type)

        evalacc_all.append(eval_acc)
        evalpreck_all.append(eval_preck)
        if eval_type == 'valid' and epoch <= n_epochs:
            print('saving model at epoch {0}'.format(epoch))
            if not os.path.exists(outputdir):
                os.makedirs(outputdir)
            torch.save(
                agent.state_dict(),
                os.path.join(outputdir, 'irecGan_agent3.' + outputmodelname))
            torch.save(
                generator.state_dict(),
                os.path.join(outputdir, 'irecGan_gen3.' + outputmodelname))

            inner_val_acc_best = eval_acc
            inner_val_preck_best = eval_preck

        if not pretrain:
            '''
            #Adjust the reward prediction
            print('Reward Adjust')
            trainSample_rewd, validSample_rewd, testSample_rewd=sampleSplit(trainindex, validindex, testindex, Seqlist, numlabel, recom_length)
            _ = train_user_pred(optims_dis, generator, bsize, embed_dim, recom_length + 1, trainSample_rewd, validSample_rewd, testSample_rewd, 'generator with rec', None, None, None, None, only_rewards = True, n_epochs=1)
            #Enable full model training
            for name, param in generator.named_parameters():
                if 'embedding' in name or 'encoder' in name or 'enc2out' in name:
                    param.requires_grad = True
            '''
            print('\nD-step')
            #Discriminator training
            for i in range(d_step):
                shutil.copy('click_gen_real.txt', write_item)
                shutil.copy('reward_gen_real.txt', write_reward)
                shutil.copy('tar_gen_real.txt', write_target)
                shutil.copy('action_gen_real.txt', write_action)
                _, _, _, _ = gen_fake(generator, agent, trainSample, bsize,
                                      embed_dim, device, write_item,
                                      write_target, write_reward, write_action,
                                      action_num, max_length, recom_length)
                clicklist, _ = ReadSeq(write_item, write_reward, write_action,
                                       write_target)
                trainindex_dis, validindex_dis, testindex_dis = split_index(
                    0.7, 0.1, len(clicklist), True)  #Shuffle the index
                trainSample_dis, validSample_dis, testSample_dis = sampleSplit(
                    trainindex_dis, validindex_dis, testindex_dis, clicklist,
                    2, recom_length, 'dis')

                discriminator, _, _ = train_dis(optims_dis, discriminator,
                                                bsize, embed_dim, recom_length,
                                                trainSample_dis,
                                                validSample_dis,
                                                testSample_dis)
        epoch += 1

    if plot_fig:
        save_plot(n_epochs, 1, evalacc_all, 'pg_accuracy6.png')
        save_plot(n_epochs, 1, evalpreck_all, 'pg_map6.png')
    return inner_val_acc_best, inner_val_preck_best
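
A hedged usage sketch for pgtrain: the optimizer strings assume the "name,key=value" format parsed by the get_optimizer helper used throughout this collection, and the models, samples, and sizes are assumed to come from earlier setup code.

# Hedged sketch; argument values are placeholders, not the original settings.
best_acc, best_preck = pgtrain(
    optims_gen='adam,lr=0.001',
    optims_dis='adam,lr=0.001',
    generator=generator, agent=agent, discriminator=discriminator,
    bsize=32, embed_dim=64,
    trainSample=trainSample, validSample=validSample, testSample=testSample,
    val_acc_best=-1e10, val_preck_best=-1e10, val_loss_best=1e10,
    action_num=num_items, max_length=20, recom_length=10,
    n_epochs=5, pretrain=False)
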
Code Example #17
File: test.py  Project: sharonzhou/generator
def test(args):
    # Get loader for z-test
    loader = get_loader(args, phase='test')
    batch_size = args.batch_size
    class_vector = None

    # TODO: make into function that takes in args.model and returns the pretrained model
    #       and also consider whether it's class conditional and what kind of class conditional (how many classes) -> probably just imagenet now, actually maybe cifar-10 too
    #       and also consider add truncation sampling as option too - this should return model, z_test noise vec, and class_vec (optionally)
    if args.ckpt_path and not args.use_pretrained:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
    else:
        if 'BigGAN' in args.model:
            num_params = int(''.join(filter(str.isdigit, args.model)))

            if 'perturbation' in args.loss_fn:
                # Use custom BigGAN with Perturbation Net wrapper
                model = models.BigGANPerturbationNet.from_pretrained(
                    f'biggan-deep-{num_params}')
            else:
                # Use pretrained BigGAN from package
                model = BigGAN.from_pretrained(f'biggan-deep-{num_params}')

            z_test = truncated_noise_sample(truncation=args.truncation,
                                            batch_size=batch_size)
            z_test = torch.from_numpy(z_test)

            # Get class conditional label
            # 981 is baseball player
            # 207 is golden retriever
            # TODO: Conditional generation only
            class_vector = one_hot_from_int(207, batch_size=batch_size)
            class_vector = torch.from_numpy(class_vector)

        elif 'WGAN-GP' in args.model:
            generator_path = "/deep/group/gen-eval/model-training/src/GAN_models/improved-wgan-pytorch/experiments/exp4_wgan_gp/generator.pt"
            model = torch.load(generator_path)
            z_test = torch.randn(batch_size, 128)

        elif 'BEGAN' in args.model:
            generator_path = "/deep/group/gen-eval/model-training/src/GAN_models/BEGAN-pytorch/trained_models/64/models/gen_97000.pth"
            model = models.BEGANGenerator()
            model.load_state_dict(torch.load(generator_path))

            z_test = np.random.uniform(-1, 1, size=(batch_size, 64))
            z_test = torch.FloatTensor(z_test)

    # Freeze model instead of using .eval()
    for param in model.parameters():
        param.requires_grad = False

    # If using perturbation net, learn perturbation layers
    if 'perturbation' in args.loss_fn:
        trainable_params = []
        for name, param in model.named_parameters():
            if 'perturb' in name:
                param.requires_grad = True
                trainable_params.append(param)
        print(f'Number of trainable params: {len(trainable_params)}')

    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)

    # Loss functions
    if 'mse' in args.loss_fn:
        pixel_criterion = torch.nn.MSELoss().to(args.device)
    else:
        pixel_criterion = torch.nn.L1Loss().to(args.device)

    if 'perceptual' in args.loss_fn:
        # Combination pixel-perceptual loss - Sec 3.2. By default, uses pixel L1.
        perceptual_criterion = torch.nn.L1Loss().to(args.device)
        perceptual_loss_weight = args.perceptual_loss_weight

        vgg_feature_extractor = models.VGGFeatureExtractor().to(args.device)
        vgg_feature_extractor.eval()
    elif 'perturbation' in args.loss_fn:
        # Perturbation network R. By default, uses pixel L1.
        # Sec 3.3: http://ganpaint.io/Bau_et_al_Semantic_Photo_Manipulation_preprint.pdf
        reg_loss_weight = args.reg_loss_weight

    # z_loss_fn = util.get_loss_fn(args.loss_fn, args)
    max_z_test_loss = 100.  # TODO: actually put max value possible here

    # Get logger, saver
    logger = TestLogger(args)
    # saver = ModelSaver(args) TODO: saver for perturbation network R

    print(f'Logs: {logger.log_dir}')
    print(f'Ckpts: {args.save_dir}')

    # Run z-test in batches
    logger.log_hparams(args)

    while not logger.is_finished_training():
        logger.start_epoch()

        for _, z_test_target, mask in loader:
            logger.start_iter()
            if torch.cuda.is_available():
                mask = mask.cuda()
                z_test = z_test.cuda()
                z_test_target = z_test_target.cuda()
                #class_vector = class_vector.cuda()

            masked_z_test_target = z_test_target * mask
            obscured_z_test_target = z_test_target * (1.0 - mask)

            if 'perturbation' in args.loss_fn:
                # With backprop on only trainable parameters in perturbation net
                params = trainable_params + [z_test.requires_grad_()]
                z_optimizer = util.get_optimizer(params, args)
            else:
                # With backprop on only the input z, run one step of z-test and get z-loss
                z_optimizer = util.get_optimizer([z_test.requires_grad_()],
                                                 args)

            with torch.set_grad_enabled(True):

                if class_vector is not None:
                    z_probs = model.forward(z_test, class_vector,
                                            args.truncation).float()
                    z_probs = (z_probs + 1) / 2.
                else:
                    z_probs = model.forward(z_test).float()

                # Calculate the masked loss using z-test vector
                masked_z_probs = z_probs * mask
                z_loss = torch.zeros(1, requires_grad=True).to(args.device)

                pixel_loss = torch.zeros(1, requires_grad=True).to(args.device)
                pixel_loss = pixel_criterion(masked_z_probs,
                                             masked_z_test_target)

                if 'perceptual' in args.loss_fn:
                    z_probs_features = vgg_feature_extractor(masked_z_probs)
                    z_test_features = vgg_feature_extractor(
                        masked_z_test_target).detach()

                    perceptual_loss = torch.zeros(1, requires_grad=True).to(
                        args.device)
                    perceptual_loss = perceptual_criterion(
                        z_probs_features, z_test_features)

                    z_loss = pixel_loss + perceptual_loss_weight * perceptual_loss
                elif 'perturbation' in args.loss_fn:
                    reg_loss = torch.zeros(1,
                                           requires_grad=True).to(args.device)
                    for name, param in model.named_parameters():
                        if 'perturb' in name:
                            delta = param - 1
                            reg_loss += torch.pow(delta, 2).mean()  #sum()
                    z_loss = pixel_loss + reg_loss_weight * reg_loss
                else:
                    z_loss = pixel_loss

                # Backprop on z-test vector
                z_loss.backward()
                z_optimizer.step()
                z_optimizer.zero_grad()

            # Compute the full loss (without mask) and obscured loss (loss only on masked region)
            # For logging and final evaluation (obscured loss is final MSE), so not in backprop loop
            full_z_loss = torch.zeros(1)
            full_pixel_loss = torch.zeros(1)
            full_pixel_loss = pixel_criterion(z_probs, z_test_target)  #.mean()

            obscured_z_probs = z_probs * (1.0 - mask)
            obscured_z_loss = torch.zeros(1)
            obscured_pixel_loss = torch.zeros(1)
            obscured_pixel_loss = pixel_criterion(
                obscured_z_probs, obscured_z_test_target)  #.mean()

            if 'perceptual' in args.loss_fn:
                # Full loss
                z_probs_full_features = vgg_feature_extractor(z_probs).detach()
                z_test_full_features = vgg_feature_extractor(
                    z_test_target).detach()

                full_perceptual_loss = torch.zeros(1)
                full_perceptual_loss = perceptual_criterion(
                    z_probs_full_features, z_test_full_features)

                full_z_loss = full_pixel_loss + perceptual_loss_weight * full_perceptual_loss

                # Obscured loss (features on the masked-out region only)
                z_probs_obscured_features = vgg_feature_extractor(
                    obscured_z_probs).detach()
                z_test_obscured_features = vgg_feature_extractor(
                    obscured_z_test_target).detach()

                obscured_perceptual_loss = torch.zeros(1)
                obscured_perceptual_loss = perceptual_criterion(
                    z_probs_obscured_features, z_test_obscured_features)

                obscured_z_loss = obscured_pixel_loss + perceptual_loss_weight * obscured_perceptual_loss
            elif 'perturbation' in args.loss_fn:
                full_z_loss = full_pixel_loss + reg_loss_weight * reg_loss
                obscured_z_loss = obscured_pixel_loss + reg_loss_weight * reg_loss
            else:
                full_z_loss = full_pixel_loss
                obscured_z_loss = obscured_pixel_loss
            """# TODO: z_loss is not always MSE anymore - figure out desired metric
            if z_loss < max_z_test_loss:
                # Save MSE on obscured region # TODO: z_loss is not always MSE anymore - figure out desired metric
                final_metrics = {'z-loss': z_loss.item(), 'obscured-z-loss': obscured_z_loss.item()}
                logger._log_scalars(final_metrics)
                print('Recall (z loss - non obscured loss - if MSE)', z_loss) 
                print('Precision (MSE value on masked region)', obscured_z_loss)
            """

            # Log both train and eval model settings, and visualize their outputs
            logger.log_status(
                masked_probs=masked_z_probs,
                masked_loss=z_loss,
                masked_test_target=masked_z_test_target,
                full_probs=z_probs,
                full_loss=full_z_loss,
                full_test_target=z_test_target,
                obscured_probs=obscured_z_probs,
                obscured_loss=obscured_z_loss,
                obscured_test_target=obscured_z_test_target,
                save_preds=args.save_preds,
            )

            logger.end_iter()

        logger.end_epoch()

    # Last log after everything completes
    logger.log_status(
        masked_probs=masked_z_probs,
        masked_loss=z_loss,
        masked_test_target=masked_z_test_target,
        full_probs=z_probs,
        full_loss=full_z_loss,
        full_test_target=z_test_target,
        obscured_probs=obscured_z_probs,
        obscured_loss=obscured_z_loss,
        obscured_test_target=obscured_z_test_target,
        save_preds=args.save_preds,
        force_visualize=True,
    )
def train_user_pred(optims,
                    generator,
                    bsize,
                    embed_dim,
                    recom_length,
                    trainSample,
                    validSample,
                    testSample,
                    mode='generator with rec',
                    inner_val_acc_best=None,
                    inner_val_preck_best=None,
                    inner_val_rewd_best=None,
                    inner_loss_best=None,
                    only_rewards=False,
                    n_epochs=10):
    outputdir = "model_output"
    outputmodelname = "simu.model.pth"
    lrshrink = 5
    minlr = 1e-5
    generator_only = True
    action_given = True
    #Define the optimizers
    #loss_fn_target = nn.CrossEntropyLoss(weight)
    loss_fn_target = nn.CrossEntropyLoss()
    #loss_fn_reward = nn.MSELoss()
    loss_fn_reward = nn.BCEWithLogitsLoss()
    loss_fn_target.size_average = True
    loss_fn_target.to(device)
    loss_fn_reward.size_average = True
    loss_fn_reward.to(device)

    optim_fn, optim_params = get_optimizer(optims)
    if mode == 'generator':
        params = list(generator.parameters())
        action_given = False
    elif mode == 'generator with rec':
        params = list(generator.parameters())
        action_given = True
    else:
        raise ValueError(
            "No such mode! Select from generator/generator with rec!")

    optimizer = optim_fn(filter(lambda p: p.requires_grad, params),
                         **optim_params)

    #n_epochs=10
    if inner_val_acc_best is None:
        inner_val_acc_best = -1e10
        inner_val_preck_best = -1e10
        inner_val_rewd_best = -1e10
        inner_loss_best = 1e10
    stop_training = False
    times_no_improvement = 0
    epoch = 1
    eval_type = 'valid'
    best_model = generator
    while not stop_training and epoch <= n_epochs:
        if not only_rewards:
            #Train click
            #print("Clicks!")
            train_acc, train_preck, _ = train_pred_each(
                generator, epoch, trainSample, optimizer, bsize, embed_dim,
                recom_length, loss_fn_target, loss_fn_reward, device,
                generator_only, action_given, False)
        #Evaluate without EOS
        print("User model evaluation!")
        eval_acc, eval_preck, eval_rewd, eval_loss = evaluate_user(
            generator, epoch, bsize, recom_length - 1, validSample, testSample,
            loss_fn_target, loss_fn_reward, device, eval_type)
        # save model
        if eval_type == 'valid' and epoch <= n_epochs:
            if eval_acc > inner_val_acc_best or eval_preck > inner_val_preck_best:
                #if inner_loss_best >= eval_loss:
                best_model = generator
                print('saving model at epoch {0}'.format(epoch))
                if not os.path.exists(outputdir):
                    os.makedirs(outputdir)
                torch.save(
                    generator.state_dict(),
                    os.path.join(outputdir, 'irecGan_gen3.' + outputmodelname))
                inner_val_acc_best = eval_acc
                inner_val_preck_best = eval_preck
                inner_val_rewd_best = eval_rewd
                inner_loss_best = eval_loss
                times_no_improvement = 0
            else:
                times_no_improvement += 1
                stop_training = adj_optim(optims, optimizer, minlr, lrshrink,
                                          stop_training, times_no_improvement)
        epoch += 1
    return best_model, inner_val_acc_best, inner_val_preck_best, inner_val_rewd_best, inner_loss_best
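
adj_optim itself is not shown in this excerpt; below is a hedged sketch of its assumed behaviour, following the common InferSent-style recipe of shrinking the SGD learning rate by `lrshrink` when validation stops improving and stopping once it falls below `minlr`.

# Hedged sketch; the actual adj_optim may differ.
def adj_optim(optims, optimizer, minlr, lrshrink, stop_training, times_no_improvement):
    if 'sgd' in optims:
        # Shrink the learning rate and stop once it gets too small.
        optimizer.param_groups[0]['lr'] /= lrshrink
        print('Shrinking lr to {0}'.format(optimizer.param_groups[0]['lr']))
        if optimizer.param_groups[0]['lr'] < minlr:
            stop_training = True
    elif times_no_improvement > 2:  # small patience for adaptive optimizers
        stop_training = True
    return stop_training
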
Code Example #19
  def __init__(self, is_training, config, input_, vocab_size, pretrained_emb):
    self._is_training = is_training
    self._input = input_
    self._rnn_params = None
    self._cell = None
    self.batch_size = input_.batch_size
    self.num_steps = input_.num_steps
    self._config = config
    self._pretrained_emb = pretrained_emb
    size = config.hidden_size

    inputs = None
    if FLAGS.use_recoding:
      if config.ec_code_generator == "preassign":
        is_one_hot = False
      else:
        is_one_hot = True
      encoder = Encoder(K=config.K,
                        D=config.D,
                        d=config.ec_code_emb_dim,
                        outd=size,
                        hparams=config,
                        vocab_size=vocab_size,
                        code_type=config.code_type,
                        code_initializer=FLAGS.code_load_filename,
                        emb_baseline=config.ec_emb_baseline,
                        code_emb_initializer=None,
                        create_code_logits=True,
                        pretrained_emb=pretrained_emb)
      if not config.disable_encoding:
        if config.ec_emb_baseline:
          code_logits = None
          if config.ec_emb_autoencoding and is_training:
            print("emb_baseline with auto-encoding enabled.")
            embsb = tf.nn.embedding_lookup(pretrained_emb, input_.input_data)
            embsb = tf.stop_gradient(embsb)
            code_logits_b = encoder.embed_transpose(
                embsb, is_training=is_training)  # (bs, steps, D, K)
            if config.ec_temperature_decay_method == "none":
              center, scale = True, True
            else:
              center, scale = False, False
            with tf.variable_scope("autoencoding_code_logits_bn", reuse=False):
              if config.ec_logits_bn > 0:
                code_logits_b = config.ec_logits_bn * tf.layers.batch_normalization(
                    code_logits_b, training=is_training, center=center, scale=scale)
            codes_m, _, code_logits = encoder.symbol2code(
                input_.input_data, is_training=is_training, output_logits=True)
            codes_b, _ = encoder.symbol2code(  # has to be after codes_m for bn issue.
                input_.input_data, logits=code_logits_b,
                logits_bn_overwrite=0, is_training=is_training)
            codes_concat = tf.concat([codes_b, codes_m], 0)
            embs = encoder.embed(  # Shared code->emb function.
                codes_concat, is_one_hot=is_one_hot, is_training=is_training)
            embsb_reconst, embs_m = tf.split(embs, 2, axis=0)
            inputs = embs_m  # (batch_size, steps, emb_dim)
            # Add regularization.
            regl_embs = tf.reduce_mean((embsb - embs_m)**2)
            regl_reconst = tf.reduce_mean((embsb - embsb_reconst)**2)
            regl_logits = tf.reduce_mean((code_logits_b - code_logits)**2)
            if is_training:
              tf.summary.scalar("regl_embs", regl_embs)
              tf.summary.scalar("regl_reconst", regl_reconst)
              tf.summary.scalar("regl_logits", regl_logits)
            emb_baseline_loss = (regl_embs * config.ec_emb_baseline_reg +
                                 regl_reconst * config.ec_emb_baseline_reg2 +
                                 regl_logits * config.ec_emb_baseline_reg3)
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                 emb_baseline_loss)
          else:
            input_codes, code_embs, embsb = encoder.symbol2code(
                input_.input_data, logits=code_logits,
                is_training=is_training, output_embb=True)
        else:  # no emb baseline.
          if config.ec_emb_autoencoding:
            raise ValueError("ec_emb_autoencoding should be False when "
                             "ec_emb_baseline is False. Please check!")
          embsb = None
          input_codes, code_embs = encoder.symbol2code(
              input_.input_data, is_training=is_training)

        if inputs is None:  # will not enter here if ec_emb_autoencoding=True.
          inputs = encoder.embed(input_codes,
                                 code_embs=code_embs,
                                 embsb=embsb,
                                 is_one_hot=is_one_hot,
                                 is_training=is_training)

        self.query_codes, _code_embs = encoder.symbol2code(
            tf.range(vocab_size), is_training=False)
        self.query_input_emb = encoder.embed(
            self.query_codes, _code_embs,
            is_one_hot=is_one_hot, is_training=False)

    if inputs is None:
      with tf.device("/gpu:0"):
        if config.emb_lowrank_dim == 0:
          embedding = tf.get_variable(
              "embedding", [vocab_size, size], dtype=data_type())
          if pretrained_emb is not None:
            self.using_pretrained_embs_on_run_op = (
                reload_embedding_after_checkpoint_recover(pretrained_emb))
          inputs = tf.nn.embedding_lookup(embedding, input_.input_data)
          self.query_codes = None
          self.query_input_emb = embedding
        else:
          _dim = config.emb_lowrank_dim
          if _dim < 1.:  # Represents ratio of low-rank parameters to full one.
            _dim = int(_dim * vocab_size / (vocab_size + size) * size)
          embedding_p = tf.get_variable(
              "embedding_p", [vocab_size, _dim], dtype=data_type())
          embedding_q = tf.get_variable(
              "embedding_q", [_dim, size], dtype=data_type())
          inputs = tf.nn.embedding_lookup(embedding_p, input_.input_data)
          inputs = tf.tensordot(inputs, embedding_q, [[-1], [0]])
          self.query_codes = None
          self.query_input_emb = embedding_p
    if config.ec_code_generator == "preassign":
      self.query_codes = None  # save memory and space.

    targets = input_.targets
    self.batch_size_final = self.batch_size
    outputs, state = self._build_rnn_graph(inputs, config, is_training)

    if config.shared_embedding:
      softmax_w = tf.transpose(embedding)
    else:
      softmax_w = tf.get_variable(
          "softmax_w", [size, vocab_size], dtype=data_type())
    self.query_output_emb = tf.transpose(softmax_w)
    softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
    logits = tf.nn.xw_plus_b(outputs, softmax_w, softmax_b)
    # Reshape logits to be a 3-D tensor for sequence loss
    logits = tf.reshape(logits,
                        [self.batch_size_final, self.num_steps, vocab_size])
    loss = tf.contrib.seq2seq.sequence_loss(
        logits,
        targets,
        tf.ones([self.batch_size_final, self.num_steps], dtype=data_type()),
        average_across_timesteps=False,
        average_across_batch=False)  # (batch_size, num_steps)

    # Update the cost
    self._nll = tf.reduce_sum(tf.reduce_mean(loss, 0))
    self._cost = self._nll
    self._final_state = state

    if not is_training:
      return


    # Add regularization.
    self._cost += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    # Add weight decay.
    if config.weight_decay != 0.:
      for tvar in tf.trainable_variables():
        self._cost += config.weight_decay * 0.5 * tf.reduce_sum(tf.square(tvar))

    self._lr = tf.Variable(0.0, trainable=False)
    non_pg_tvars = tf.trainable_variables()
    if config.optimizer == "mixed":
      if config.ec_code_generator == "preassign":
        raise ValueError("Shouldn't use mixed when using preassign.")
      mixed_encode = False
      pg_tvars = tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, scope="Model/code/code_logits")
      pg_tvars += tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, scope="Model/symbol2code")
      if mixed_encode:
        pg_tvars += tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope="Model/encode")
      tvars, non_pg_tvars = non_pg_tvars, []
      for tvar in tvars:
        if tvar.name.startswith("Model/code/code_logits") or (
            tvar.name.startswith("Model/symbol2code")):
          continue
        if mixed_encode and tvar.name.startswith("Model/encode"):
          continue
        non_pg_tvars.append(tvar)
      if len(tvars) == len(non_pg_tvars) or (
          len(pg_tvars) + len(non_pg_tvars) != len(tvars)):
        raise ValueError("Check pg_tvars and non_pg_tvars separation!")
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        grads = tf.gradients(self._cost, pg_tvars + non_pg_tvars)
        # Globally clip_gradients treatment.
        # grads, _ = tf.clip_by_global_norm(grads, config.max_grad_norm)
        grads_pg = [grads[i] for i in range(len(pg_tvars))]
        _start = len(pg_tvars)
        grads_non_pg = [grads[i + _start] for i in range(len(non_pg_tvars))]
        tf.summary.scalar("pg_grad_norm", tf.global_norm(grads_pg))
        tf.summary.scalar("nonpg_grad_norm", tf.global_norm(grads_non_pg))
        # Only clip on the grads_non_pg, instead of global treatment.
        grads_non_pg, _ = tf.clip_by_global_norm(grads_non_pg,
                                                 config.max_grad_norm)
        optimizer_pg = tf.contrib.opt.LazyAdamOptimizer(self._lr / 100.)
        optimizer_nonpg = tf.train.GradientDescentOptimizer(self._lr)
        train_op_pg = optimizer_pg.apply_gradients(zip(grads_pg, pg_tvars))
        train_op_nonpg = optimizer_nonpg.apply_gradients(
            zip(grads_non_pg, non_pg_tvars),
            global_step=tf.train.get_or_create_global_step())
      self._train_op = tf.group(train_op_pg, train_op_nonpg)
    elif config.optimizer == "scheduled_sgd":
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        grads = tf.gradients(self._cost, non_pg_tvars)
        tf.summary.scalar("global_grad_norm", tf.global_norm(grads))
        grads, _ = tf.clip_by_global_norm(grads,
                                          config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self._lr)
        self._train_op = optimizer.apply_gradients(
            zip(grads, non_pg_tvars),
            global_step=tf.train.get_or_create_global_step())
    else:
      self._train_op = tf.contrib.layers.optimize_loss(
          loss=self._cost,
          global_step=tf.train.get_or_create_global_step(),
          learning_rate=tf.convert_to_tensor(self._lr),
          optimizer=util.get_optimizer(config.optimizer),
          variables=non_pg_tvars,
          clip_gradients=float(config.max_grad_norm),
          summaries=["learning_rate", "loss", "global_gradient_norm"])
    self._new_lr = tf.placeholder(
        tf.float32, shape=[], name="new_learning_rate")
    self._lr_update = tf.assign(self._lr, self._new_lr)
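
The `_new_lr` placeholder and `_lr_update` op above are normally driven by a small helper method on the model; that helper is not shown in this excerpt, so the following is only a hedged sketch of the usual pattern.

  def assign_lr(self, session, lr_value):
    # Feed the new learning rate into the assign op defined in __init__.
    session.run(self._lr_update, feed_dict={self._new_lr: lr_value})
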
Code Example #20
def train(args):

    if args.ckpt_path and not args.use_pretrained:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        if args.use_pretrained:
            model.load_pretrained(args.ckpt_path, args.gpu_ids)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    if args.use_pretrained or args.fine_tune:
        parameters = model.module.fine_tuning_parameters(
            args.fine_tuning_boundary, args.fine_tuning_lr)
    else:
        parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    lr_scheduler = util.get_scheduler(optimizer, args)
    if args.ckpt_path and not args.use_pretrained and not args.fine_tune:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    cls_loss_fn = util.get_loss_fn(is_classification=True,
                                   dataset=args.dataset,
                                   size_average=False)
    data_loader_fn = data_loader.__dict__[args.data_loader]
    train_loader = data_loader_fn(args, phase='train', is_training=True)
    logger = TrainLogger(args, len(train_loader.dataset),
                         train_loader.dataset.pixel_dict)
    eval_loaders = [data_loader_fn(args, phase='val', is_training=False)]
    evaluator = ModelEvaluator(args.do_classify, args.dataset, eval_loaders,
                               logger, args.agg_method, args.num_visuals,
                               args.max_eval, args.epochs_per_eval)
    saver = ModelSaver(args.save_dir, args.epochs_per_save, args.max_ckpts,
                       args.best_ckpt_metric, args.maximize_metric)

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, target_dict in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                inputs = inputs.to(args.device)
                cls_logits = model.forward(inputs)
                cls_targets = target_dict['is_abnormal']
                cls_loss = cls_loss_fn(cls_logits, cls_targets.to(args.device))
                loss = cls_loss.mean()

                logger.log_iter(inputs, cls_logits, target_dict,
                                cls_loss.mean(), optimizer)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()
            util.step_scheduler(lr_scheduler, global_step=logger.global_step)

        metrics, curves = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.best_ckpt_metric, None))
        logger.end_epoch(metrics, curves)
        util.step_scheduler(lr_scheduler,
                            metrics,
                            epoch=logger.epoch,
                            best_ckpt_metric=args.best_ckpt_metric)
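
util.get_optimizer is imported from a module that is not shown here; the sketch below is a hedged guess at what a PyTorch version might look like, and the attribute names on `args` are assumptions.

# Hedged sketch of a possible util.get_optimizer implementation.
import torch.optim as optim

def get_optimizer(parameters, args):
    if args.optimizer == 'sgd':
        return optim.SGD(parameters, lr=args.lr, momentum=args.sgd_momentum,
                         weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        return optim.Adam(parameters, lr=args.lr,
                          weight_decay=args.weight_decay)
    raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))
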
Code Example #21
  def forward(self, features, labels, is_training=False):
    """Returns loss, preds, train_op.

    Args:
      features: (batch_size, max_seq_length)
      labels: (batch_size, num_classes)

    Returns:
      loss: (batch_size, )
      preds: (batch_size, )
      train_op: op.
    """
    num_classes = labels.shape.as_list()[-1]
    batch_size = tf.shape(features)[0]
    mask = tf.cast(tf.greater(features, 0), tf.float32)  # (bs, max_seq_length)
    lengths = tf.reduce_sum(mask, axis=1, keepdims=True)  # (batch_size, 1)

    # Embedding
    if FLAGS.kdq_type == "none":
      inputs = full_embed(features, FLAGS.vocab_size, FLAGS.dims)
    else:
      kdq_hparam = KDQhparam(
          K=FLAGS.K, D=FLAGS.D, kdq_type=FLAGS.kdq_type,
          kdq_d_in=FLAGS.kdq_d_in, kdq_share_subspace=FLAGS.kdq_share_subspace,
          additive_quantization=FLAGS.additive_quantization)
      inputs = kdq_embed(
          features, FLAGS.vocab_size, FLAGS.dims, kdq_hparam, is_training)
    word_embs = inputs  # (bs, length, emb_dim)
    word_embs *= tf.expand_dims(mask, -1)

    embs_maxpool = tf.reduce_max(word_embs, 1)  # Max pooling.
    embs_meanpool = tf.reduce_sum(word_embs, 1) / lengths  # Mean pooling.
    if FLAGS.concat_maxpooling:
      embs = tf.concat([embs_meanpool, embs_maxpool], -1)
    else:
      embs = embs_meanpool
    if FLAGS.hidden_layers > 0:
      embs = tf.nn.relu(
          tf.layers.batch_normalization(embs, training=is_training))
      embs = tf.layers.dense(embs, FLAGS.dims)
      embs = tf.nn.relu(
          tf.layers.batch_normalization(embs, training=is_training))
    logits = tf.layers.dense(embs, num_classes)
    preds = tf.argmax(logits, -1)[:batch_size]
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels, logits=logits)

    if is_training:
      # Regular loss updater.
      loss_scalar = tf.reduce_mean(loss)
      loss_scalar += FLAGS.reg_weight * tf.reduce_mean(word_embs**2)
      loss_scalar += tf.reduce_sum(
          tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
      train_op = tf.contrib.layers.optimize_loss(
          loss=loss_scalar,
          global_step=tf.train.get_or_create_global_step(),
          learning_rate=FLAGS.learning_rate,
          optimizer=util.get_optimizer(FLAGS.optimizer),
          variables=tf.trainable_variables())
    else:
      train_op = False
      loss_scalar = None

    return loss_scalar, preds, train_op
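
A hedged driver sketch for the forward() above, assuming TF1 placeholders, a `model` instance exposing that method, and a `batches` iterator; `FLAGS.max_seq_length` and `num_classes` are assumptions not defined in this excerpt.

# Hedged sketch of a minimal training loop.
features = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length])
labels = tf.placeholder(tf.float32, [None, num_classes])
loss, preds, train_op = model.forward(features, labels, is_training=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for batch_features, batch_labels in batches:  # assumed data iterator
        _, batch_loss = sess.run([train_op, loss],
                                 feed_dict={features: batch_features,
                                            labels: batch_labels})
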