def test_map_location_with_cuda():
    """Check that map_location('cuda:0') returns a callable which invokes
    ``.cuda('cuda:0')`` on the first argument it is given."""
    remap = map_location('cuda:0')
    fake_storage = Mock()
    # Explicit child mock so the call made through ``remap`` can be inspected.
    fake_storage.cuda = Mock()
    remap(fake_storage, '')
    fake_storage.cuda.assert_called_with('cuda:0')
def test_map_location_with_cpu():
    """Check that the device string 'cpu:0' is mapped to plain 'cpu'."""
    expected = 'cpu'
    assert map_location('cpu:0') == expected
def load_model(self, fname):
    """Load a checkpoint from ``fname`` and apply it to ``self``.

    The checkpoint is loaded with ``torch.load`` using the location mapping
    produced by ``map_location(self.device)``. The loaded object is then
    passed to ``set_state_dict`` together with ``self``.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects, depending on
    the installed torch version's ``weights_only`` default. Only load
    checkpoints from trusted sources.
    """
    remap = map_location(self.device)
    checkpoint = torch.load(fname, map_location=remap)
    set_state_dict(self, checkpoint)