Example #1
0
      model.step = old_step
    if torch.cuda.is_available() and \
       self.options.gpu is not None and \
       self.options.gpu >= 0:
      model.cuda(self.options.gpu)

    return model

  def _on_get_args(self, *args, **kwargs):
    warnings.warn(
      ('_on_get_args is deprecated, get rid of this as soon as old '
       'model files are no longer needed'),
      DeprecationWarning)


# Named logger for the load_env code path, built with elf's
# `stderr_color_mt` factory (presumably a colored stderr logger - confirm
# in elf.logging).
_load_env_logger = logging.stderr_color_mt(
    'rlpytorch.model_loader.load_env')


def load_env(
    envs,
    num_models=None,
    overrides=None,
    additional_to_load=None):
  """Load envs.

  Envs will be specified as environment variables. Specifically, the
  environment variables ``game``, ``model_file`` and ``model`` are
  required.

  ``additional_to_load`` is a dict with the following format:
Example #2
0
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torch.autograd import Variable

import elf.logging as logging
from elf.options import auto_import_options, PyOptionSpec
from rlpytorch.trainer.timer import RLTimer

# Module-wide factory for named loggers; every logger it creates is
# obtained from `logging.stderr_color_mt`. IndexedLoggerFactory presumably
# appends a per-instance index to the requested name - confirm in
# elf.logging.
_logger_factory = logging.IndexedLoggerFactory(
    lambda logger_name: logging.stderr_color_mt(logger_name),
)


class MyOptim(object):
    @classmethod
    def get_option_spec(cls):
        """Return the option spec understood by this class.

        Declares one boolean option, ``backprop`` ("Whether to backprop the
        total loss"), registered with the value ``True`` (presumably its
        default - confirm against PyOptionSpec.addBoolOption).
        """
        option_spec = PyOptionSpec()
        option_spec.addBoolOption(
            'backprop',
            'Whether to backprop the total loss',
            True,
        )
        return option_spec

    @auto_import_options
    def __init__(self, option_map):
        self.policy_loss = nn.KLDivLoss().cuda()
        self.value_loss = nn.MSELoss().cuda()
        self.logger = _logger_factory.makeLogger('elfgames.go.MCTSPrediction-',
                                                 '')
Example #3
0
# LICENSE file in the root directory of this source tree.

import importlib
import pprint
import random
import time
import warnings

from elf.options import import_options, PyOptionSpec
from elf import logging
from .model_interface import ModelInterface
from .sampler import Sampler
from .utils.fp16_utils import FP16Model

# Module-wide factory for named loggers; every logger it creates is
# obtained from `logging.stderr_color_mt`. IndexedLoggerFactory presumably
# appends a per-instance index to the requested name - confirm in
# elf.logging.
_logger_factory = logging.IndexedLoggerFactory(
    lambda logger_name: logging.stderr_color_mt(logger_name),
)


def load_module(mod):
    """Import and return the Python module named ``mod``.

    Args:
        mod: Absolute dotted module path, e.g. ``"package.sub.module"``.

    Returns:
        The imported module object (the one cached in ``sys.modules`` if it
        was already imported).

    Raises:
        ImportError: (``ModuleNotFoundError``) if the module cannot be found.
    """
    # The previous version printed ``(module, mod)`` to stdout on every call;
    # a library helper should not write debug output to stdout.
    return importlib.import_module(mod)


class ModelLoader(object):
    """Class to load a previously saved model."""
    @classmethod
    def get_option_spec(cls, model_class=None, model_idx=None):
        spec = PyOptionSpec()
        spec.addStrOption('load', 'load model', '')
Example #4
0
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from torch.autograd import Variable

import elf.logging as logging
from elf.options import auto_import_options, PyOptionSpec
from rlpytorch.trainer.timer import RLTimer


# Module-wide factory for named loggers; every logger it creates is
# obtained from `logging.stderr_color_mt`. IndexedLoggerFactory presumably
# appends a per-instance index to the requested name - confirm in
# elf.logging.
_logger_factory = logging.IndexedLoggerFactory(
    lambda logger_name: logging.stderr_color_mt(logger_name),
)


class MCTSPrediction(object):
    @classmethod
    def get_option_spec(cls):
        """Return the option spec understood by MCTSPrediction.

        Declares one boolean option, ``backprop`` ("Whether to backprop the
        total loss"), registered with the value ``True`` (presumably its
        default - confirm against PyOptionSpec.addBoolOption).
        """
        option_spec = PyOptionSpec()
        option_spec.addBoolOption(
            'backprop', 'Whether to backprop the total loss', True)
        return option_spec

    @auto_import_options
    def __init__(self, option_map):
        self.policy_loss = nn.KLDivLoss().cuda()