Ejemplo n.º 1
0
def test_epoch(dataloader, network, loss, loggers):
    """Evaluate ``network`` on every batch of ``dataloader`` without gradients.

    Each batch loss is printed, logged (logger in "test" mode) and plotted.
    After the loop, the targets/outputs of the last batch are plotted as
    images and the mean loss over all batches is printed.

    Args:
        dataloader: iterable of ``(input, target)`` pairs; ``len()`` is used
            to size the progress printer.
        network: model whose ``forward`` maps inputs to outputs.
        loss: callable ``loss(output, target)`` returning a loss tensor.
        loggers: ``(logger, vizlogger)`` pair.

    Returns:
        The mean per-batch loss, or ``None`` when ``dataloader`` yielded no
        batches (nothing is plotted or printed in that case).
    """
    logger, vizlogger = loggers

    printer = Printer(N=len(dataloader))
    logger.set_mode("test")
    use_cuda = torch.cuda.is_available()

    total_loss = 0.
    # Count batches explicitly: the last ``enumerate`` index is N-1, so
    # dividing by it would skew the mean (and divide by zero for N == 1).
    num_batches = 0

    with torch.no_grad():
        for iteration, (inputs, target) in enumerate(dataloader):
            if use_cuda:
                inputs = inputs.cuda()
                target = target.cuda()

            output = network.forward(inputs)
            # Copy the loss to host memory once and reuse it for stats and sum.
            batch_loss = loss(output, target).data.cpu().numpy()

            stats = {"loss": batch_loss}
            total_loss += batch_loss
            num_batches += 1

            printer.print(stats, iteration)
            logger.log(stats, iteration)
            vizlogger.plot_steps(logger.get_data())

        if num_batches == 0:
            # Empty loader: there is no last batch to plot and no mean.
            return None

        vizlogger.plot_images(target.cpu().detach().numpy(),
                              output.cpu().detach().numpy())

    mean_loss = total_loss / num_batches
    print('Loss: %.4f' % mean_loss)
    return mean_loss
Ejemplo n.º 2
0
def train_epoch(dataloader, network, optimizer, loss, loggers):
    """Run one optimisation pass of ``network`` over ``dataloader``.

    For every ``(input, target)`` batch: clear the gradients, compute the
    loss, backpropagate and step ``optimizer``. Then print, log (logger in
    "train" mode) and plot the batch loss.

    Args:
        dataloader: iterable of ``(input, target)`` pairs; ``len()`` is used
            to size the progress printer.
        network: model whose ``forward`` maps inputs to outputs.
        optimizer: object exposing ``zero_grad()`` and ``step()``.
        loss: callable ``loss(output, target)`` returning a loss tensor.
        loggers: ``(logger, vizlogger)`` pair.
    """
    logger, vizlogger = loggers
    printer = Printer(N=len(dataloader))
    logger.set_mode("train")

    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        inputs, targets = batch
        if torch.cuda.is_available():
            inputs, targets = inputs.cuda(), targets.cuda()

        batch_loss = loss(network.forward(inputs), targets)
        stats = {"loss": batch_loss.data.cpu().numpy()}

        batch_loss.backward()
        optimizer.step()

        printer.print(stats, step)
        logger.log(stats, step)
        vizlogger.plot_steps(logger.get_data())
Ejemplo n.º 3
0
def init_logger(args, algorithm):
    """Create the run's text logger and TensorBoard summary writer.

    The run directory is
    ``os.path.join(args.logdir, args.env_name, f"{folder}_{stamp}")``.
    Here ``folder`` is ``algorithm.get_log_folder_name()`` and ``stamp`` is
    the current local time as ``YYYYmmdd_HHMMSS``.

    A ``Logger`` writing to that directory is used when ``args.logdir`` is
    non-empty; otherwise a ``Printer`` is used.

    NOTE(review): the ``SummaryWriter`` is always created at the run
    directory. When ``args.logdir`` is ``''`` that directory is relative to
    the working directory — confirm this is intended.

    Returns:
        ``(logger, summary_writer)``.
    """
    folder = algorithm.get_log_folder_name()
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    run_path = os.path.join(args.logdir, f'{args.env_name}', f'{folder}_{stamp}')

    if args.logdir != '':
        logger = Logger(path=run_path)
    else:
        logger = Printer()
    return logger, SummaryWriter(log_dir=run_path)
Ejemplo n.º 4
0
    def __init__(self, args):
        """Store ``args``, reset worker bookkeeping and read env dimensions.

        A gym environment is built from ``args`` only to obtain
        ``state_dim`` and ``action_dim``; it is closed right away.
        """
        self.args = args
        # Placeholders: None here; presumably assigned elsewhere in the class.
        self.replay = self.summary = None
        self.logger = Printer()

        self.critic_workers, self.actor_workers = [], []
        self.population = None

        self.critic_worker_num = self.actor_worker_num = self.individual_dim = 0

        probe_env, self.state_dim, self.action_dim = init_gym_from_args(args)
        probe_env.close()
Ejemplo n.º 5
0
 def __init__(self, logger=None):
     """Start with empty worker/job maps; default the logger to a ``Printer``."""
     self.workers, self.jobs = {}, {}
     self.logger = Printer() if logger is None else logger