Example #1
    def __init__(self, **kwrds):
        # Start from the default configuration and override it with any
        # keyword arguments, rejecting unknown keys.
        self.config = utils.Config(copy.deepcopy(const.config))
        for key, value in kwrds.items():
            assert key in self.config.keys(), '{} is not a keyword, \n acceptable keywords: {}'.\
                format(key, self.config.keys())
            self.config[key] = value

        self.experiments_root_dir = 'experiments'
        utils.create_dirs([self.experiments_root_dir])
        self.config.model_name = const.get_model_name(self.config.model_type,
                                                      self.config)
        # Place the checkpoint, summary and log directories for this model
        # under the experiments root.
        self.config.checkpoint_dir = os.path.join(
            self.experiments_root_dir, self.config.checkpoint_dir,
            self.config.model_name)
        self.config.summary_dir = os.path.join(
            self.experiments_root_dir, self.config.summary_dir,
            self.config.model_name)
        self.config.log_dir = os.path.join(
            self.experiments_root_dir, self.config.log_dir,
            self.config.model_name)

        utils.create_dirs([
            self.config.checkpoint_dir, self.config.summary_dir,
            self.config.log_dir
        ])
        # Restore the configuration from a previous run with the same model
        # name, if one was saved; otherwise keep the current settings.
        load_config = {}
        try:
            load_config = utils.load_args(self.config.model_name,
                                          self.config.summary_dir)
            self.config.update(load_config)
            self.config.update({
                key: const.config[key]
                for key in ['kinit', 'bias_init', 'act_out', 'transfer_fct']
            })
            print('Loading previous configuration ...')
        except Exception:
            print('Unable to load previous configuration ...')

        utils.save_args(self.config.dict(), self.config.model_name,
                        self.config.summary_dir)

        if self.config.plot:
            self.latent_space_files = list()
            self.latent_space3d_files = list()
            self.recons_files = list()

        # If the image dimensions are already known, try to build the model
        # graph right away and mark it for restoring.
        if hasattr(self.config, 'height'):
            try:
                self.config.restore = True
                self.build_model(self.config.height, self.config.width,
                                 self.config.num_channels)
            except Exception:
                self.isBuild = False
        else:
            self.isBuild = False
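The snippet above relies on utils.load_args and utils.save_args to persist the resolved configuration next to the summaries. Those helpers are not shown here; a minimal sketch, assuming they serialize the configuration dict to a JSON file named after the model inside the summary directory (file format and bodies below are assumptions, not the project's actual code):

import json
import os


def save_args(config_dict, model_name, summary_dir):
    # Hypothetical helper: write the config dict to <summary_dir>/<model_name>.json.
    path = os.path.join(summary_dir, model_name + '.json')
    with open(path, 'w') as f:
        json.dump(config_dict, f, indent=2)


def load_args(model_name, summary_dir):
    # Hypothetical helper: read the config dict back; raises FileNotFoundError
    # if no previous run exists, which the caller above catches.
    path = os.path.join(summary_dir, model_name + '.json')
    with open(path) as f:
        return json.load(f)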
Example #2
File: train.py Project: xlnwel/cpcgan
def train_cpcgan(cpc_epochs, gan_epochs):
    name = 'cpcgan'
    # Read this model's hyperparameters from the shared args file, keyed by name.
    cpcgan_args = utils.load_args()[name]
    batch_size = cpcgan_args['batch_size']
    terms = cpcgan_args['cpc']['hist_terms']
    predict_terms = cpcgan_args['cpc']['future_terms']
    image_size = cpcgan_args['image_shape'][0]
    color = cpcgan_args['color']

    train_data = SortedNumberGenerator(batch_size=batch_size,
                                       subset='train',
                                       terms=terms,
                                       positive_samples=batch_size // 2,
                                       predict_terms=predict_terms,
                                       image_size=image_size,
                                       color=color,
                                       rescale=False)

    validation_data = SortedNumberGenerator(batch_size=batch_size,
                                            subset='valid',
                                            terms=terms,
                                            positive_samples=batch_size // 2,
                                            predict_terms=predict_terms,
                                            image_size=image_size,
                                            color=color,
                                            rescale=False)

    # Let TensorFlow grow GPU memory as needed instead of pre-allocating it all.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    cpcgan = timeit(lambda: CPCGAN(
        name, cpcgan_args, sess=sess, reuse=False, log_tensorboard=True),
                    name='CPCGAN')

    # cpcgan.restore_cpc()  # restore cpc only
    cpcgan.restore()  # restore the entire cpcgan
    if cpc_epochs != 0:
        train_cpc(cpcgan, cpc_epochs, train_data, validation_data)

    train_gan(cpcgan, gan_epochs, train_data, validation_data)
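Here utils.load_args() returns a nested dict of hyperparameters keyed by model name. The actual config file is not shown; the structure below is only inferred from the keys accessed above, with placeholder values:

# Hypothetical shape of the dict returned by utils.load_args(); the values are
# placeholders, only the keys are taken from the accesses in the example.
example_args = {
    'cpcgan': {
        'batch_size': 64,
        'image_shape': [64, 64, 3],
        'color': True,
        'cpc': {
            'hist_terms': 4,
            'future_terms': 4,
        },
    },
}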
Example #3
import gym
import random
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
import utils.utils as utils
import tensorflow as tf
from ddpg_tf import DDPG

env = gym.make('BipedalWalker-v2')
env.seed(0)

# Build the DDPG agent from the shared args file and restore saved weights.
sess = tf.Session()
agent = DDPG('ddpg', utils.load_args(), sess=sess)
agent.restore()
def ddpg(n_episodes=10000, max_t=1000):
    # Run up to n_episodes episodes, training the agent at every step and
    # keeping a moving window of the last 100 episode scores.
    scores_deque = deque(maxlen=100)
    scores = []
    for i_episode in range(1, n_episodes+1):
        state = env.reset()
        score = 0
        for t in range(max_t):
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.step(state, action, reward, next_state, done)
            state = next_state
            score += reward
            if done:
                break 
        scores_deque.append(score)
Example #4
def main():
    # Get arguments and start logging
    args = parser.parse_args()
    load_args(args.config, args)
    logpath = os.path.join(args.cv_dir, args.name)
    os.makedirs(logpath, exist_ok=True)
    save_args(args, logpath, args.config)
    writer = SummaryWriter(log_dir=logpath, flush_secs=30)

    # Get dataset
    trainset = dset.CompositionDataset(
        root=os.path.join(DATA_FOLDER, args.data_dir),
        phase='train',
        split=args.splitname,
        model=args.image_extractor,
        num_negs=args.num_negs,
        pair_dropout=args.pair_dropout,
        update_features=args.update_features,
        train_only=args.train_only,
        open_world=args.open_world
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers)
    testset = dset.CompositionDataset(
        root=os.path.join(DATA_FOLDER, args.data_dir),
        phase=args.test_set,
        split=args.splitname,
        model=args.image_extractor,
        subset=args.subset,
        update_features=args.update_features,
        open_world=args.open_world
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.workers)


    # Get model and optimizer
    image_extractor, model, optimizer = configure_model(args, trainset)
    args.extractor = image_extractor

    train = train_normal

    evaluator_val = Evaluator(testset, model)

    print(model)

    start_epoch = 0
    # Load checkpoint
    if args.load is not None:
        checkpoint = torch.load(args.load)
        if image_extractor:
            try:
                image_extractor.load_state_dict(checkpoint['image_extractor'])
                if args.freeze_features:
                    print('Freezing image extractor')
                    image_extractor.eval()
                    for param in image_extractor.parameters():
                        param.requires_grad = False
            except Exception:
                print('No Image extractor in checkpoint')
        model.load_state_dict(checkpoint['net'])
        start_epoch = checkpoint['epoch']
        print('Loaded model from ', args.load)
    
    for epoch in tqdm(range(start_epoch, args.max_epochs + 1), desc='Current epoch'):
        train(epoch, image_extractor, model, trainloader, optimizer, writer)
        if model.is_open and args.model == 'compcos' and (epoch + 1) % args.update_feasibility_every == 0:
            print('Updating feasibility scores')
            model.update_feasibility(epoch + 1.)

        if epoch % args.eval_val_every == 0:
            with torch.no_grad(): # todo: might not be needed
                test(epoch, image_extractor, model, testloader, evaluator_val, writer, args, logpath)
    print('Best AUC achieved is ', best_auc)
    print('Best HM achieved is ', best_hm)
Example #5
    def _get_models(self):
        return utils.load_args('models.yaml')
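In this project load_args takes a file path directly rather than a model name and directory. A minimal sketch, assuming it simply parses a YAML file into a dict (the helper body below is an assumption, not the project's actual code):

import yaml


def load_args(path):
    # Assumed implementation: parse the YAML file at `path` into a plain dict.
    with open(path) as f:
        return yaml.safe_load(f)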
Example #6
def main():
    # Get arguments and start logging
    args = parser.parse_args()
    logpath = args.logpath
    # Pick up the YAML config saved in the log directory and merge it into args.
    config = [
        os.path.join(logpath, _) for _ in os.listdir(logpath)
        if _.endswith('yml')
    ][0]
    load_args(config, args)

    # Get dataset
    trainset = dset.CompositionDataset(root=os.path.join(
        DATA_FOLDER, args.data_dir),
                                       phase='train',
                                       split=args.splitname,
                                       model=args.image_extractor,
                                       update_features=args.update_features,
                                       train_only=args.train_only,
                                       subset=args.subset,
                                       open_world=args.open_world)

    valset = dset.CompositionDataset(root=os.path.join(DATA_FOLDER,
                                                       args.data_dir),
                                     phase='val',
                                     split=args.splitname,
                                     model=args.image_extractor,
                                     subset=args.subset,
                                     update_features=args.update_features,
                                     open_world=args.open_world)

    valoader = torch.utils.data.DataLoader(valset,
                                           batch_size=args.test_batch_size,
                                           shuffle=False,
                                           num_workers=8)

    testset = dset.CompositionDataset(root=os.path.join(
        DATA_FOLDER, args.data_dir),
                                      phase='test',
                                      split=args.splitname,
                                      model=args.image_extractor,
                                      subset=args.subset,
                                      update_features=args.update_features,
                                      open_world=args.open_world)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers)

    # Get model and optimizer
    image_extractor, model, optimizer = configure_model(args, trainset)
    args.extractor = image_extractor

    args.load = ospj(logpath, 'ckpt_best_auc.t7')

    checkpoint = torch.load(args.load)
    if image_extractor:
        try:
            image_extractor.load_state_dict(checkpoint['image_extractor'])
            image_extractor.eval()
        except Exception:
            print('No Image extractor in checkpoint')
    model.load_state_dict(checkpoint['net'])
    model.eval()

    threshold = None
    if args.open_world and args.hard_masking:
        assert args.model == 'compcos', args.model + ' does not have hard masking.'
        if args.threshold is not None:
            threshold = args.threshold
        else:
            # No threshold given: search for the value that maximizes AUC on
            # the validation set.
            evaluator_val = Evaluator(valset, model)
            unseen_scores = model.compute_feasibility().to('cpu')
            seen_mask = model.seen_mask.to('cpu')
            # Offset seen pairs so the min/max are taken over unseen pairs only.
            min_feasibility = (unseen_scores + seen_mask * 10.).min()
            max_feasibility = (unseen_scores - seen_mask * 10.).max()
            thresholds = np.linspace(min_feasibility,
                                     max_feasibility,
                                     num=args.threshold_trials)
            best_auc = 0.
            best_th = -10
            with torch.no_grad():
                for th in thresholds:
                    results = test(image_extractor,
                                   model,
                                   valoader,
                                   evaluator_val,
                                   args,
                                   threshold=th,
                                   print_results=False)
                    auc = results['AUC']
                    if auc > best_auc:
                        best_auc = auc
                        best_th = th
                        print('New best AUC', best_auc)
                        print('Threshold', best_th)

            threshold = best_th

    evaluator = Evaluator(testset, model)

    with torch.no_grad():
        test(image_extractor, model, testloader, evaluator, args, threshold)