コード例 #1
0
 def _wrap_distributed(self):
     """Wrap every trainable module in a parallel wrapper when one is requested.

     A module counts as trainable when at least one of its parameters has
     ``requires_grad`` set; other modules are left untouched. Wrapped modules
     replace the originals in ``self.modules``.

     * ``distributed_launch`` set: BatchNorm layers are converted with
       ``SyncBatchNorm.convert_sync_batchnorm`` and the module is wrapped in
       DDP with ``device_ids=[self.device]``.
     * otherwise, ``data_parallel_backend`` set: the module is wrapped in DP,
       over all GPUs when ``data_parallel_count == -1`` or over GPU indices
       ``0 .. data_parallel_count - 1`` otherwise.
     * neither set: nothing happens.
     """
     if self.distributed_launch:
         for key, net in self.modules.items():
             if not any(param.requires_grad for param in net.parameters()):
                 continue
             # For DDP, all modules must run on the same GPU (self.device).
             synced = SyncBatchNorm.convert_sync_batchnorm(net)
             self.modules[key] = DDP(synced, device_ids=[self.device])
     elif self.data_parallel_backend:
         for key, net in self.modules.items():
             if not any(param.requires_grad for param in net.parameters()):
                 continue
             # data_parallel_count == -1 -> let DP use every GPU; any other
             # value selects the first data_parallel_count GPUs explicitly.
             if self.data_parallel_count == -1:
                 wrapped = DP(net)
             else:
                 wrapped = DP(net, list(range(self.data_parallel_count)))
             self.modules[key] = wrapped
コード例 #2
0
    def build_model(self):
        """Create the networks, losses, optimizers and learning-rate schedulers.

        Networks: an image encoder, a generator (built with
        ``seg_channel=self.seg_channel``) and two discriminators with
        ``down_scale`` 1 and 2. Each is wrapped in ``DP``, moved to
        ``self.device`` and initialised with ``networks._init_weights``
        (Xavier normal, per the original section note).

        Optimizers: one Adam over generator + encoder parameters (``lr_G``)
        and one over both discriminators (``lr_D``). Each is driven by a
        ``LambdaLR`` schedule that:

        * keeps the base learning rate up to ``start_annealing_epoch``;
        * decays it linearly to ``end_lr`` at ``end_annealing_epoch``;
        * holds ``end_lr`` afterwards.
        """
        ########## Networks ##########
        self.enc = DP(image_encoder()).to(self.device)
        self.gen = DP(generator(seg_channel=self.seg_channel)).to(self.device)
        self.disORI = DP(discriminator(down_scale=1)).to(self.device)
        self.disHAL = DP(discriminator(down_scale=2)).to(self.device)
        # Third discriminator, currently disabled:
        # self.disQUA = DP(discriminator(down_scale=4)).to(self.device)

        ########## Init Networks with Xavier normal ##########
        # Same order as before, so any RNG-dependent init is unchanged.
        for net in (self.enc, self.gen, self.disORI, self.disHAL):
            net.apply(networks._init_weights)
        # self.disQUA.apply(networks._init_weights)

        ########## Loss ##########
        self.KLloss = KL_loss(self.device)
        self.GAN_D_loss = GAN_D_loss(self.device,
                                     GAN_D_loss_type=self.GAN_D_loss_type)
        self.GAN_G_loss = GAN_G_loss(self.device,
                                     GAN_G_loss_type=self.GAN_G_loss_type)
        self.FMloss = FM_loss(self.device)
        if self.use_vgg:
            self.VGGloss = VGG_loss(self.device)

        ########## LR schedule ##########
        def make_annealing_lambda(base_lr):
            """Return an ``lr_lambda`` for ``LambdaLR`` anchored at *base_lr*.

            ``LambdaLR`` sets ``lr = base_lr * lr_lambda(epoch)``, so the
            returned function yields a multiplicative factor (1.0 before
            annealing, ``end_lr / base_lr`` after it), not an absolute
            learning rate. Returning an absolute rate would square the base LR.
            """
            def lr_factor(epoch):
                start = self.start_annealing_epoch
                end = self.end_annealing_epoch
                if epoch <= start:
                    return 1.0
                # A zero base LR stays zero whatever the factor; avoid
                # dividing by it.
                end_factor = self.end_lr / base_lr if base_lr else 1.0
                # Reached whenever end <= epoch, including start == end, so
                # the division below never sees a zero-length window.
                if epoch >= end:
                    return end_factor
                progress = (epoch - start) / (end - start)
                return 1.0 + progress * (end_factor - 1.0)

            return lr_factor

        ########## Optimizer ##########
        self.G_optim = torch.optim.Adam(
            list(self.gen.parameters()) + list(self.enc.parameters()),
            lr=self.lr_G,
            betas=(self.beta_1, self.beta_2))
        self.G_lambda = make_annealing_lambda(self.lr_G)
        self.G_optim_sch = torch.optim.lr_scheduler.LambdaLR(
            self.G_optim, lr_lambda=self.G_lambda)

        self.D_optim = torch.optim.Adam(
            list(self.disORI.parameters()) + list(self.disHAL.parameters()),
            # + list(self.disQUA.parameters())  (disabled third discriminator)
            lr=self.lr_D,
            betas=(self.beta_1, self.beta_2))
        self.D_lambda = make_annealing_lambda(self.lr_D)
        self.D_optim_sch = torch.optim.lr_scheduler.LambdaLR(
            self.D_optim, lr_lambda=self.D_lambda)
コード例 #3
0
ファイル: base_task.py プロジェクト: mtli/llcv
    def __init__(self, args, loader, is_train):
        """Build the model, place it on the right device(s) and set up training.

        Args:
            args: run configuration. This constructor reads ``device``,
                ``gpu_gather``, ``exp_dir`` and ``use_cuda``. When training it
                also reads ``lr``, ``lr_update_per_epoch`` and ``optim``.
                ``args`` is forwarded to ``build_model``, ``set_optim``,
                ``set_lr_schedule`` and ``auto_load``.
            loader: data loader; its ``dataset`` is kept and passed to
                ``build_model``.
            is_train: when true, the optimizer and LR schedule are created.
        """
        self.loader = loader
        self.dataset = loader.dataset
        self.is_train = is_train
        self.device = args.device
        self.gather = False
        self.gpu_gather = args.gpu_gather
        self.resume_epoch = 0
        self.has_val_score = False
        self.exp_dir = args.exp_dir
        if is_train:
            self.last_lr = args.lr
            self.lr_update_per_epoch = args.lr_update_per_epoch

        self.model = build_model(args, self.dataset)
        # Lazy %-style arguments: the (potentially very large) model repr is
        # only rendered when DEBUG logging is actually enabled.
        logging.debug('%s', self.model)
        logging.debug(
            'Total number of parameters: %s',
            sum(p.numel() for p in self.model.parameters()),
        )

        self.rank = dist_get_rank()
        if self.rank >= 0:
            # Presumably rank >= 0 means a distributed run: move the model to
            # this process's current CUDA device (or CPU) and wrap it in DDP.
            self.device = (
                torch.cuda.current_device() if args.use_cuda else 'cpu'
            )
            self.model = self.model.to(self.device)
            self.model = DDP(
                self.model,
                device_ids=[self.device] if args.use_cuda else None,
                find_unused_parameters=True,
            )
        elif args.use_cuda:
            # Single process: use DataParallel only when several GPUs exist.
            if torch.cuda.device_count() > 1:
                self.model = DP(self.model)
            self.model = self.model.to(self.device)
        # NOTE(review): this uses args.device, not self.device (which is
        # replaced by the current CUDA device in the DDP branch) — confirm.
        self.output_device = args.device if args.gpu_gather else 'cpu'

        if is_train:
            logging.debug(
                f'Optimizer: {args.optim} with base learning rate {args.lr:.6g}'
            )
            self.set_optim(args)
            self.set_lr_schedule(args)

        self.auto_load(args)
コード例 #4
0
 def _wrap_distributed(self):
     """Replace each trainable entry of ``self.modules`` with a parallel wrapper.

     A module is trainable when any of its parameters has ``requires_grad``
     set. The wrapper is chosen once:

     * ``distributed_launch`` set: convert BatchNorm layers with
       ``SyncBatchNorm.convert_sync_batchnorm`` and wrap in DDP on
       ``self.device``, forwarding ``self.find_unused_parameters``.
     * otherwise, ``data_parallel_backend`` set: wrap in DP with its defaults.
     * neither set: return without changing anything.
     """
     if self.distributed_launch:
         def wrap(net):
             synced = SyncBatchNorm.convert_sync_batchnorm(net)
             return DDP(
                 synced,
                 device_ids=[self.device],
                 find_unused_parameters=self.find_unused_parameters,
             )
     elif self.data_parallel_backend:
         wrap = DP
     else:
         return

     for key, net in self.modules.items():
         if any(param.requires_grad for param in net.parameters()):
             self.modules[key] = wrap(net)
コード例 #5
0
ファイル: kg.py プロジェクト: vasu0403/aqa
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
        device_ids = list(range(torch.cuda.device_count()))
        gpus = len(device_ids)
        print('GPU detected')
    else:
        DEVICE = torch.device("cpu")
        device_ids = -1
        print('No GPU. switching to CPU')

    model = 'textattack/albert-base-v2-MRPC'
    tokenizer = AutoTokenizer.from_pretrained(model)
    model = AutoModelForSequenceClassification.from_pretrained(model)
    if not device_ids == -1:
        print('Porting model to CUDA...')
        model = DP(model, device_ids=device_ids)
        model.to(f'cuda:{model.device_ids[0]}')
    model.eval()

    if args.dtypes is None:
        dtypes = ['mini'] if args.debug else ['validation', 'train']
        if not dataset.name == 'squad' and not args.debug:
            dtypes.append('test')
    else:
        dtypes = args.dtypes.split(',')
    print('Running KG builder for {} sets'.format(', '.join(dtypes)))

    results = []
    for dtype in dtypes:
        start_time = time()
        oie_fn = os.path.join(data_dir, 'oie_data', 'predictions_{}.json'.format(dtype))
コード例 #6
0
ファイル: models.py プロジェクト: e259f381/RTG
    def __init__(
            self,
            env: MultiAgentVecEnv,
            device="cpu",
            dtype=torch.float32,
            memory_units=256,
            out_features=256,
            model="default",
            data_parallel=False,
            lstm_mode='cat',
            nan_check=False,
            input_shape=None,  # if set overrides environment observation shape
    ):

        input_shape = input_shape or env.observation_space.shape

        assert env.observation_space.dtype == np.uint8, "Observation space should be 8bit integer"

        self.lstm_mode = lstm_mode
        if self.lstm_mode == 'residual':
            assert memory_units == out_features
            self.encoder_output_features = out_features
        elif self.lstm_mode == 'off':
            self.encoder_output_features = out_features
        elif self.lstm_mode == 'on':
            self.encoder_output_features = memory_units
        elif self.lstm_mode == 'cat':
            self.encoder_output_features = memory_units + out_features
        else:
            raise ValueError(f"invalid lstm mode {self.lstm_mode}.")

        super().__init__()

        # environment is channels last, but we work with channels first.
        self.input_shape = tuple(input_shape)

        self.n_actions = env.action_space.n
        self.device = device
        self.dtype = dtype
        self.lstm_mode = lstm_mode
        self.nan_check = nan_check

        if model.lower() == "default":
            self.encoder = DefaultEncoder(self.input_shape,
                                          out_features=out_features)
        elif model.lower() == "fast":
            self.encoder = FastEncoder(self.input_shape,
                                       out_features=out_features)
        else:
            raise Exception(
                f"Invalid model {model}, expected [default|fast|global]")

        self.encoder_features = self.encoder.out_features
        self.memory_units = memory_units

        # note, agents are AI controlled, players maybe scripted, but an AI still needs to predict their behaviour.
        self.n_agents = env.total_agents
        self.n_players = env.max_players

        # hardcode this to 3 for simplicity, could use env.max_roles, but then it might change from game to game
        self.n_roles = 3

        # memory units
        self.lstm = torch.nn.LSTM(input_size=self.encoder_features,
                                  hidden_size=self.memory_units,
                                  num_layers=1,
                                  batch_first=False,
                                  dropout=0)

        self.encoder_type = type(
            self.encoder
        )  # the type of encoder before and data_parallel was applied
        if data_parallel:
            # enable multi gpu :)
            print(
                f" -enabling {utils.Color.OKGREEN}Multi-GPU{utils.Color.ENDC} support"
            )
            self.encoder = DP(self.encoder)