Example #1
0
    def __init__(self, cfg):
        """Prepare the model, data pipeline and Visdom plotter for training.

        The model is built and checkpointed first, optionally restored from
        ``cfg.load_snapshot``, and only afterwards is the dataset/dataloader
        created.

        Args:
            cfg: experiment configuration. This method reads ``log_path``,
                ``exp_name``, ``gan_type``, ``load_snapshot``, ``dataset``,
                ``img_size``, ``batch_size``, ``num_workers`` and
                ``vis_server``, and keeps a reference as ``self.cfg``.

        Raises:
            Exception: if ``train_images_*`` directories are already present
                in the experiment directory.
        """
        exp_dir = os.path.join(cfg.log_path, cfg.exp_name)
        stale_pattern = os.path.join(exp_dir, 'train_images_*')
        if glob.glob(stale_pattern):
            raise Exception('all directories with name train_images_* under '
                            'the experiment directory need to be removed')

        # Build the model and write a checkpoint into the experiment
        # directory *before* any snapshot is restored.
        # NOTE(review): the meaning of the two zero arguments to save_model
        # is not visible here (presumably iteration/epoch) -- confirm.
        self.model = MODELS[cfg.gan_type](cfg)
        self.model.save_model(exp_dir, 0, 0)
        snapshot = cfg.load_snapshot
        if snapshot is not None:
            self.model.load_model(snapshot)

        # Only the 'recurrent_gan' model type iterates over unshuffled data.
        use_shuffle = cfg.gan_type != 'recurrent_gan'

        self.dataset = DATASETS[cfg.dataset](
            path=keys[cfg.dataset], cfg=cfg, img_size=cfg.img_size)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=cfg.batch_size,
            shuffle=use_shuffle,
            num_workers=cfg.num_workers,
            pin_memory=True,
            drop_last=True,
        )

        # Dataset-specific batching; any other dataset keeps the
        # DataLoader's default collate function.
        collate = None
        if cfg.dataset == 'codraw':
            collate = codraw_dataset.collate_data
        elif cfg.dataset == 'iclevr':
            collate = clevr_dataset.collate_data
        if collate is not None:
            self.dataloader.collate_fn = collate

        self.visualizer = VisdomPlotter(env_name=cfg.exp_name,
                                        server=cfg.vis_server)
        self.logger = None  # no logger is attached at construction time
        self.cfg = cfg
Example #2
0
    def __init__(self, cfg):
        """Prepare the data pipeline, model and Visdom plotter for training.

        The dataset is built first. ``cfg.vocab_size`` is overwritten with the
        dataset's vocabulary size before the model constructor receives
        ``cfg``; presumably this is so the model is sized to the dataset's
        vocabulary (confirm in the model classes).

        Args:
            cfg: experiment configuration. This method reads ``log_path``,
                ``exp_name``, ``dataset``, ``img_size``, ``batch_size``,
                ``num_workers``, ``gan_type``, ``load_snapshot`` and
                ``vis_server``, writes ``vocab_size``, and keeps a reference
                as ``self.cfg``.

        Raises:
            Exception: if ``train_images_*`` directories are already present
                in the experiment directory.
        """
        exp_dir = os.path.join(cfg.log_path, cfg.exp_name)
        if glob.glob(os.path.join(exp_dir, 'train_images_*')):
            raise Exception('all directories with name train_images_* under '
                            'the experiment directory need to be removed')

        # Echo the dataset path looked up in `keys` (also passed as path=).
        print(keys[cfg.dataset])
        self.dataset = DATASETS[cfg.dataset](
            path=keys[cfg.dataset], cfg=cfg, img_size=cfg.img_size)
        cfg.vocab_size = self.dataset.vocab_size

        self.dataloader = DataLoader(
            self.dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
            drop_last=True,
        )

        # Dataset-specific batching; any other dataset keeps the
        # DataLoader's default collate function.
        dataset_name = cfg.dataset
        if dataset_name in ('codraw', 'codrawDialog'):
            self.dataloader.collate_fn = codraw_dataset.collate_data
        elif dataset_name == 'iclevr':
            self.dataloader.collate_fn = clevr_dataset.collate_data
        elif dataset_name in ('gandraw', 'gandraw_clean', 'gandraw_64',
                              'gandraw_64_DA'):
            self.dataloader.collate_fn = gandraw_dataset.collate_data

        # Build the model, checkpoint it, then optionally restore a snapshot.
        # NOTE(review): the meaning of the two zero arguments to save_model
        # is not visible here (presumably iteration/epoch) -- confirm.
        self.model = MODELS[cfg.gan_type](cfg)
        self.model.save_model(exp_dir, 0, 0)

        snapshot = cfg.load_snapshot
        if snapshot is not None:
            print("load the model from: {}".format(snapshot))
            self.model.load_model(snapshot)

        self.visualizer = VisdomPlotter(env_name=cfg.exp_name,
                                        server=cfg.vis_server)
        self.logger = None  # no logger is attached at construction time
        self.cfg = cfg
Example #3
0
    def __init__(self, cfg):
        """Prepare the data pipeline, teller model and Visdom plotter.

        Args:
            cfg: experiment configuration. This method reads ``log_path``,
                ``exp_name``, ``dataset``, ``img_size``, ``batch_size``,
                ``num_workers``, ``gan_type`` and ``vis_server``, overwrites
                ``vocab_size`` with the dataset's vocabulary size, and keeps a
                reference as ``self.cfg``.

        Raises:
            Exception: if ``train_images_*`` directories are already present
                in the experiment directory.
            ValueError: if ``cfg.gan_type`` is not ``'recurrent_gan_teller'``.
        """
        img_path = os.path.join(cfg.log_path,
                                cfg.exp_name,
                                'train_images_*')
        if glob.glob(img_path):
            raise Exception('all directories with name train_images_* under '
                            'the experiment directory need to be removed')

        # Validate the model type before the dataset and dataloader are built.
        # An explicit raise is used instead of `assert`, because assertions
        # are stripped when Python runs with -O and the check would vanish.
        if cfg.gan_type != "recurrent_gan_teller":
            raise ValueError(
                "To run a teller trainer, you will need to use "
                "'recurrent_gan_teller' as gan_type")

        path = os.path.join(cfg.log_path, cfg.exp_name)
        self.dataset = DATASETS[cfg.dataset](
            path=keys[cfg.dataset], cfg=cfg, img_size=cfg.img_size)
        # Update the cfg's vocab_size before the model sees cfg.
        cfg.vocab_size = self.dataset.vocab_size

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers,
                                     pin_memory=True,
                                     drop_last=True)

        # Dataset-specific batching; any other dataset keeps the
        # DataLoader's default collate function.
        if cfg.dataset in ('codraw', 'codrawDialog'):
            self.dataloader.collate_fn = codraw_dataset.collate_data
        elif cfg.dataset in ("gandraw", "gandraw_clean", "gandraw_64",
                             "gandraw_64_DA"):
            self.dataloader.collate_fn = gandraw_dataset.collate_data

        # Load the model and write an initial checkpoint into the experiment
        # directory.
        # NOTE(review): the meaning of the two zero arguments to save_model
        # is not visible here (presumably iteration/epoch) -- confirm.
        self.model = MODELS[cfg.gan_type](cfg)
        self.model.save_model(path, 0, 0)

        # Launch the visualizer.
        self.visualizer = VisdomPlotter(
            env_name=cfg.exp_name, server=cfg.vis_server)
        self.logger = None  # no logger is attached at construction time

        self.cfg = cfg
from torch import nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from tqdm import tqdm

from geneva.models.object_localizer import inception_v3
from geneva.utils.config import keys
from geneva.utils.visualize import VisdomPlotter


# Download locations of pretrained weights, keyed by model name.
model_urls = dict(
    # Inception v3 ported from TensorFlow
    inception_v3_google='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
)

# Module-level Visdom plotter for the 'inception_localizer' environment.
# NOTE(review): it is created at import time with a hard-coded server URL;
# whether construction contacts the server depends on VisdomPlotter -- confirm.
viz = VisdomPlotter(
    env_name='inception_localizer',
    server='http://localhost',
)


class Inception3ObjectLocalizer(nn.Module):
    """Inception-v3 backbone with an object-detection head and a localization head.

    The backbone's final ``fc`` layer is replaced by a 512-unit projection.
    ``detector`` maps 512 features to ``num_objects`` sigmoid scores in (0, 1).
    ``localizer`` maps 1024 features to ``num_objects * num_coords`` values.

    Args:
        num_objects: number of object classes handled by both heads.
        pretrained: forwarded to ``inception_v3``; presumably selects
            pretrained backbone weights (TODO confirm in
            geneva.models.object_localizer).
        num_coords: number of values predicted per object by ``localizer``.
    """
    def __init__(self, num_objects=58, pretrained=True, num_coords=3):
        super().__init__()
        # The construction order below determines both the order in which
        # parameters are registered (state_dict / parameters()) and the order
        # in which layer initialisation draws from the RNG; keep it stable.
        self.inception3 = inception_v3(pretrained=pretrained)
        # Replace the backbone classifier with a 512-d feature projection.
        self.inception3.fc = nn.Linear(self.inception3.fc.in_features, 512)

        # NOTE(review): these two Linear layers have no nonlinearity between
        # them, so together they form a single affine map. Inserting an
        # activation would shift the Sequential indices and break existing
        # checkpoints' state_dict keys.
        self.detector = nn.Sequential(nn.Linear(512, 256),
                                      nn.Linear(256, num_objects),
                                      nn.Sigmoid())
        # NOTE(review): the input width is 1024, twice the 512-d backbone
        # output; presumably forward() concatenates two 512-d vectors -- confirm.
        self.localizer = nn.Sequential(nn.Linear(1024, 512),
                                       nn.Linear(512, num_objects * num_coords))
        self.num_objects = num_objects
        self.num_coords = num_coords
Example #5
0
        del tester
        torch.cuda.empty_cache()
        metrics_report = dict()
        if self.cfg.metric_inception_objects:
            io_jss, io_ap, io_ar, io_af1, io_cs, io_gs = \
                report_inception_objects_score(self.visualizer,
                                               self.logger,
                                               iteration,
                                               self.cfg.results_path,
                                               keys[self.cfg.dataset + '_inception_objects'],
                                               keys[self.cfg.val_dataset],
                                               self.cfg.dataset)

            metrics_report['jaccard'] = io_jss
            metrics_report['precision'] = io_ap
            metrics_report['recall'] = io_ar
            metrics_report['f1'] = io_af1
            metrics_report['cossim'] = io_cs
            metrics_report['relsim'] = io_gs
            
        return metrics_report


if __name__ == '__main__':
    # Standalone evaluation entry point: parse the configuration, set up
    # plotting, and run the evaluator on the configured dataset.
    cfg = parse_config()
    logger = None  # no logger is wired up for standalone runs
    dataset = cfg.dataset
    # NOTE(review): no `server` argument is passed here, so VisdomPlotter's
    # own default server is used.
    visualizer = VisdomPlotter(env_name=cfg.exp_name)
    evaluator = Evaluator(cfg, visualizer, logger, dataset)
    evaluator.evaluate()
Example #6
0
from geneva.data.datasets import DATASETS
from geneva.evaluation.evaluate import Evaluator
from geneva.utils.config import keys, parse_config
from geneva.utils.visualize import VisdomPlotter
from geneva.models.models import MODELS
from geneva.data import gandraw_dataset

import time
if __name__ == '__main__':
    # Evaluate a trained model at a fixed checkpoint iteration on test data
    # (use_test=True) and print every metric the evaluator reports.

    # json/easydict are imported here so the script does not rely on them
    # being imported at module level; repeating an import is harmless.
    import json

    import easydict

    config_file = "example_args/gandraw_args.json"

    # Decode the JSON config explicitly as UTF-8 rather than with the
    # platform's locale encoding.
    with open(config_file, 'r', encoding='utf-8') as f:
        cfg = json.load(f)

    # Attribute-style access (cfg.exp_name, cfg.vis_server, ...).
    cfg = easydict.EasyDict(cfg)
    best_iteration = 1500  # Manually chosen checkpoint iteration.

    # Initialize the evaluator.
    visualizer = VisdomPlotter(env_name=cfg.exp_name, server=cfg.vis_server)
    logger = None

    evaluator = Evaluator.factory(cfg, visualizer, logger, visualize_images=[])
    metric_report = evaluator.evaluate(best_iteration, use_test=True)

    print("evaluation results for iter: {} on test data: \n".format(
        best_iteration))
    for key, value in metric_report.items():
        print("{metric_name}: {metric_value}; \n".format(metric_name=key,
                                                         metric_value=value))