示例#1
0
        cfg_from_file(args.cfg_file)
    # Command-line GPU id overrides the config value unless left at -1.
    if args.gpu_id != -1:
        cfg.GPU_ID = args.gpu_id
    print('Using config:')
    pprint.pprint(cfg)

    # Local-time timestamp; used below to name the training log directory.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

    # Splits are read from Data/<DATASET_NAME>/test and .../train.
    datadir = 'Data/%s' % cfg.DATASET_NAME
    dataset = TextDataset(datadir, cfg.EMBEDDING_TYPE, 1)
    filename_test = '%s/test' % (datadir)
    dataset.test = dataset.get_data(filename_test)
    if cfg.TRAIN.FLAG:
        # Training: also load the train split and make a fresh, timestamped
        # checkpoint/log directory (via mkdir_p).
        filename_train = '%s/train' % (datadir)
        dataset.train = dataset.get_data(filename_train)
        ckt_logs_dir = "ckt_logs/%s/%s_%s" % (cfg.DATASET_NAME,
                                              cfg.CONFIG_NAME, timestamp)
        mkdir_p(ckt_logs_dir)
    else:
        # Not training: derive the output directory from the pretrained
        # model path by cutting at the first '.ckpt'. str.partition leaves
        # the path intact when '.ckpt' is absent; slicing with
        # str.find('.ckpt') would get -1 there and drop the last character.
        s_tmp = cfg.TRAIN.PRETRAINED_MODEL
        ckt_logs_dir = s_tmp.partition('.ckpt')[0]

    model = CondGAN(image_shape=dataset.image_shape)
    algo = CondGANTrainer(model=model,
                          dataset=dataset,
                          ckt_logs_dir=ckt_logs_dir)
    if cfg.TRAIN.FLAG:
        algo.train()
    else:
        ''' For every input text embedding/sentence in the
示例#2
0
    print('Using config:')
    pprint.pprint(cfg)

    # Local-time timestamp; used below to name the training log directory.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

    # cfg.DATASET_NAME is used directly as the data location (no 'Data/'
    # prefix), and the same location is handed to get_data for every split.
    # NOTE(review): train and test are read from the same path here —
    # confirm this is intended.
    datadir = cfg.DATASET_NAME
    dataset = TextDataset(datadir, cfg.EMBEDDING_TYPE, 4)
    filename_test = datadir
    dataset.test = dataset.get_data(filename_test)
    if cfg.TRAIN.FLAG:
        # Training: also load the train split and make a fresh, timestamped
        # checkpoint/log directory (via mkdir_p).
        filename_train = datadir
        dataset.train = dataset.get_data(filename_train)
        ckt_logs_dir = "ckt_logs/%s/%s_%s" % \
            (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
        mkdir_p(ckt_logs_dir)
    else:
        # Not training: derive the output directory from the pretrained
        # model path by cutting at the first '.ckpt'. str.partition leaves
        # the path intact when '.ckpt' is absent; slicing with
        # str.find('.ckpt') would get -1 there and drop the last character.
        s_tmp = cfg.TRAIN.PRETRAINED_MODEL
        ckt_logs_dir = s_tmp.partition('.ckpt')[0]

    # lr_imsize is presumably the low-resolution image size: the dataset's
    # image size divided by the HR/LR ratio. int() keeps it integral under
    # true division.
    model = CondGAN(lr_imsize=int(dataset.image_shape[0] /
                                  dataset.hr_lr_ratio),
                    hr_lr_ratio=dataset.hr_lr_ratio)

    algo = CondGANTrainer(model=model,
                          dataset=dataset,
                          ckt_logs_dir=ckt_logs_dir)
示例#3
0
from misc.datasets import TextDataset
from model import CondGAN
from trainer import CondGANTrainer
from misc.get_configs import parse_args
from misc.utils import mkdir_p

if __name__ == "__main__":
    # Parse and echo the command-line configuration.
    args = parse_args()
    print(args)

    # Naive local wall-clock time, formatted as a sortable string.
    now = datetime.datetime.now()
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

    # Dataset rooted at datasets/<name>/; load its training data.
    dataset = TextDataset(datadir='datasets/' + args.dataset + '/')
    print("Dataset created!")
    dataset.train = dataset.get_data()

    model = CondGAN(args, image_shape=dataset.image_shape)
    print("model created!")

    # Per-dataset output directories: checkpoints/logs, then retrieval
    # results. Both are created (via mkdir_p) in that order.
    ckt_logs_dir = "ckt_logs/{}_logs".format(args.dataset)
    res_dir = "retrieved_res/{}_res".format(args.dataset)
    for out_dir in (ckt_logs_dir, res_dir):
        mkdir_p(out_dir)

    # Record the parsed arguments next to the checkpoints.
    with open(ckt_logs_dir + '/args.txt', 'w') as fid:
        fid.write('%s\n' % (args,))
    algo = CondGANTrainer(args,
示例#4
0
        cfg_from_file(args.cfg_file)
    # Command-line GPU id overrides the config value unless left at -1.
    if args.gpu_id != -1:
        cfg.GPU_ID = args.gpu_id
    print('Using config:')
    pprint.pprint(cfg)

    # Local-time timestamp; used below to name the training log directory.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

    # Splits are read from Data/<DATASET_NAME>/test and .../train.
    datadir = 'Data/%s' % cfg.DATASET_NAME
    dataset = TextDataset(datadir, cfg.EMBEDDING_TYPE, 1)
    filename_test = '%s/test' % (datadir)
    dataset.test = dataset.get_data(filename_test)
    if cfg.TRAIN.FLAG:
        # Training: also load the train split and make a fresh, timestamped
        # checkpoint/log directory (via mkdir_p).
        filename_train = '%s/train' % (datadir)
        dataset.train = dataset.get_data(filename_train)

        ckt_logs_dir = "ckt_logs/%s/%s_%s" % \
            (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
        mkdir_p(ckt_logs_dir)
    else:
        # Not training: derive the output directory from the pretrained
        # model path by cutting at the first '.ckpt'. str.partition leaves
        # the path intact when '.ckpt' is absent; slicing with
        # str.find('.ckpt') would get -1 there and drop the last character.
        s_tmp = cfg.TRAIN.PRETRAINED_MODEL
        ckt_logs_dir = s_tmp.partition('.ckpt')[0]

    model = CondGAN(
        image_shape=dataset.image_shape
    )

    algo = CondGANTrainer(
        model=model,
        dataset=dataset,
示例#5
0
        cfg_from_file(args.cfg_file)
    # Command-line GPU id overrides the config value unless left at -1.
    if args.gpu_id != -1:
        cfg.GPU_ID = args.gpu_id
    print('Using config:')
    pprint.pprint(cfg)

    # Local-time timestamp; used below to name the training log directory.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

    # cfg.DATASET_NAME is used directly as the data location (no 'Data/'
    # prefix).
    datadir = cfg.DATASET_NAME
    dataset = TextDataset(datadir, cfg.EMBEDDING_TYPE, 1)
    filename_test = datadir
    # NOTE(review): dataset.test is never loaded in this variant, so a
    # non-training run reaches the trainer without a test split — confirm
    # the evaluation path does not need it.
    if cfg.TRAIN.FLAG:
        # Training: load the train split without augmentation and make a
        # fresh, timestamped checkpoint/log directory (via mkdir_p).
        filename_train = datadir
        dataset.train = dataset.get_data(filename_train, aug_flag=False)

        ckt_logs_dir = "ckt_logs/%s/%s_%s" % \
            (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
        mkdir_p(ckt_logs_dir)
    else:
        # Not training: derive the output directory from the pretrained
        # model path by cutting at the first '.ckpt'. str.partition leaves
        # the path intact when '.ckpt' is absent; slicing with
        # str.find('.ckpt') would get -1 there and drop the last character.
        s_tmp = cfg.TRAIN.PRETRAINED_MODEL
        ckt_logs_dir = s_tmp.partition('.ckpt')[0]

    model = CondGAN(image_shape=dataset.image_shape)

    algo = CondGANTrainer(model=model,
                          dataset=dataset,
                          ckt_logs_dir=ckt_logs_dir)
    if cfg.TRAIN.FLAG:
        algo.train()