Example #1
0
def main():
    """Train a DCGAN-style generator on Atari screen frames.

    Command line:
        --cuda  enable CUDA computation (default: CPU).

    Side effects: logs smoothed losses via ``log`` and writes real/fake
    image grids plus loss metrics to TensorBoard. Runs until the batch
    iterator is exhausted (potentially forever).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda",
                        default=False,
                        action="store_true",
                        help="Enable cuda computation")
    args = parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")
    # Mix frames from several visually different games into one dataset.
    envs = [
        InputWrapper(gym.make(name))
        for name in ("Breakout-v0", "AirRaid-v0", "Pong-v0")
    ]
    input_shape = envs[0].observation_space.shape

    net_discr = Discriminator(input_shape=input_shape).to(device)
    net_gener = Generator(output_shape=input_shape).to(device)

    objective = nn.BCELoss()
    gen_optimizer = optim.Adam(params=net_gener.parameters(),
                               lr=LEARNING_RATE,
                               betas=(0.5, 0.999))
    dis_optimizer = optim.Adam(params=net_discr.parameters(),
                               lr=LEARNING_RATE,
                               betas=(0.5, 0.999))

    # Label tensors are constant across iterations, so build them once.
    true_labels_v = torch.ones(BATCH_SIZE, device=device)
    fake_labels_v = torch.zeros(BATCH_SIZE, device=device)

    def process_batch(trainer, batch):
        """One GAN step: update the discriminator, then the generator.

        Returns ``(dis_loss, gen_loss)`` as Python floats for the
        RunningAverage metrics attached below.
        """
        # Sample latent vectors directly on the target device; replaces
        # the legacy FloatTensor(...).normal_().to(device) pattern, which
        # allocated on the CPU and then copied.
        gen_input_v = torch.randn(
            BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1, device=device)
        batch_v = batch.to(device)
        gen_output_v = net_gener(gen_input_v)

        # Train discriminator: real frames labelled 1, generated 0.
        # detach() keeps discriminator gradients out of the generator.
        dis_optimizer.zero_grad()
        dis_output_true_v = net_discr(batch_v)
        dis_output_fake_v = net_discr(gen_output_v.detach())
        dis_loss = objective(dis_output_true_v, true_labels_v) + objective(
            dis_output_fake_v, fake_labels_v)
        dis_loss.backward()
        dis_optimizer.step()

        # Train generator: push discriminator output on fakes toward 1.
        gen_optimizer.zero_grad()
        dis_output_v = net_discr(gen_output_v)
        gen_loss = objective(dis_output_v, true_labels_v)
        gen_loss.backward()
        gen_optimizer.step()

        if trainer.state.iteration % SAVE_IMAGE_EVERY_ITER == 0:
            # .detach() instead of the deprecated .data attribute.
            fake_img = vutils.make_grid(
                gen_output_v.detach()[:64], normalize=True)
            trainer.tb.writer.add_image("fake", fake_img,
                                        trainer.state.iteration)
            real_img = vutils.make_grid(
                batch_v.detach()[:64], normalize=True)
            trainer.tb.writer.add_image("real", real_img,
                                        trainer.state.iteration)
            trainer.tb.writer.flush()
        return dis_loss.item(), gen_loss.item()

    engine = Engine(process_batch)
    tb = tb_logger.TensorboardLogger(log_dir=None)
    engine.tb = tb
    # out = (dis_loss, gen_loss) from process_batch.
    RunningAverage(output_transform=lambda out: out[1]).attach(
        engine, "avg_loss_gen")
    RunningAverage(output_transform=lambda out: out[0]).attach(
        engine, "avg_loss_dis")

    handler = tb_logger.OutputHandler(
        tag="train", metric_names=["avg_loss_gen", "avg_loss_dis"])
    tb.attach(engine,
              log_handler=handler,
              event_name=Events.ITERATION_COMPLETED)

    @engine.on(Events.ITERATION_COMPLETED)
    def log_losses(trainer):
        # Report the smoothed metrics periodically to the module logger.
        if trainer.state.iteration % REPORT_EVERY_ITER == 0:
            log.info(
                "%d: gen_loss=%f, dis_loss=%f",
                trainer.state.iteration,
                trainer.state.metrics["avg_loss_gen"],
                trainer.state.metrics["avg_loss_dis"],
            )

    engine.run(data=iterate_batches(envs))
        # NOTE(review): the lines below are an orphaned, truncated duplicate
        # of the training-loop tail above — the enclosing `def` was lost when
        # this file was scraped, the fragment is indented for a scope that no
        # longer exists, and it ends mid-statement. It is not runnable and
        # should be deleted once confirmed it carries nothing unique.
        gen_optimizer.step()

        if trainer.state.iteration % SAVE_IMAGE_EVERY_ITER == 0:
            trainer.tb.writer.add_image(
                'fake', vutils.make_grid(gen_output_v.data[:64],
                                         normalize=True),
                trainer.state.iteration)
            trainer.tb.writer.add_image(
                'real', vutils.make_grid(batch_v.data[:64], normalize=True),
                trainer.state.iteration)

        return dis_loss.item(), gen_loss.item()

    engine = Engine(process_batch)
    tb = tb_logger.TensorboardLogger(log_dir=None)
    engine.tb = tb
    RunningAverage(output_transform=lambda out: out[0]).attach(
        engine, 'avg_loss_dis')
    RunningAverage(output_transform=lambda out: out[1]).attach(
        engine, 'avg_loss_gen')
    handler = tb_logger.OutputHandler(
        tag='train', metric_names=['avg_loss_dis', 'avg_loss_gen'])
    tb.attach(engine,
              log_handler=handler,
              event_name=Events.ITERATION_COMPLETED)

    @engine.on(Events.ITERATION_COMPLETED)
    def log_losses(trainer):
        if trainer.state.iteration % REPORT_EVERY_ITER == 0:
            log.info('Iter %d: gen_loss=%.3f, dis_loss=%.3f',
                     trainer.state.iteration,
Example #3
0
def train():
    """Train a GAN on frames sampled from several Atari games.

    Side effects: writes real/fake image grids and flushes TensorBoard
    every 100 iterations, and prints the raw loss tuple every 100
    iterations. Runs until the batch generator is exhausted.
    """
    learning_rate = 0.0001
    save_on_iter_count = 100
    # Fall back to CPU instead of crashing when CUDA is unavailable;
    # on CUDA machines this is identical to the previous hard-coded value.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    envs = [
        ObservationScaler(gym.make(name))
        for name in ("Breakout-v0", "Pong-v0", "AirRaid-v0")
    ]
    discriminator = Discriminator(img_size=64).to(device)
    generator = Generator().to(device)
    objective = nn.BCELoss()
    discr_optimizer = optim.Adam(params=discriminator.parameters(),
                                 lr=learning_rate,
                                 betas=(0.5, 0.999))
    gen_optimizer = optim.Adam(params=generator.parameters(),
                               lr=learning_rate,
                               betas=(0.5, 0.999))

    def process_batch(trainer, batch):
        """One adversarial step; returns (discr_loss, gen_loss) floats."""
        batch_size = batch.shape[0]
        gen_input_size = 10

        # Labels and latent noise are created directly on the target
        # device, avoiding a CPU allocation followed by a copy. The
        # generator output is already on `device`, so no extra .to() call.
        generator_inputs = torch.randn(
            (batch_size, gen_input_size, 1, 1), device=device)
        fake_inputs = generator(generator_inputs)
        true_inputs = batch.to(device)
        fake_image_labels = torch.zeros(batch_size, device=device)
        true_image_labels = torch.ones(batch_size, device=device)

        # Train discriminator: real→1, fake→0. detach() keeps its
        # gradients from flowing back into the generator.
        discr_optimizer.zero_grad()
        discr_fake_image_output = discriminator(fake_inputs.detach())
        discr_true_image_output = discriminator(true_inputs)

        discr_loss = objective(discr_fake_image_output,
                               fake_image_labels) + objective(
                                   discr_true_image_output, true_image_labels)

        discr_loss.backward()
        discr_optimizer.step()

        # Train generator: reward it for making the discriminator say "real".
        gen_optimizer.zero_grad()
        discr_output = discriminator(fake_inputs)
        gen_loss = objective(discr_output, true_image_labels)
        gen_loss.backward()
        gen_optimizer.step()

        # Periodically save image grids; .detach() replaces deprecated .data.
        if trainer.state.iteration % save_on_iter_count == 0:
            fake_img = vutils.make_grid(
                fake_inputs.detach()[:64], normalize=True)
            trainer.tb.writer.add_image("fake", fake_img,
                                        trainer.state.iteration)
            real_img = vutils.make_grid(
                true_inputs.detach()[:64], normalize=True)
            trainer.tb.writer.add_image("real", real_img,
                                        trainer.state.iteration)
            trainer.tb.writer.flush()
        return discr_loss.item(), gen_loss.item()

    engine = Engine(process_batch)
    tb = tb_logger.TensorboardLogger(log_dir=None)
    engine.tb = tb
    # out = (discr_loss, gen_loss) from process_batch.
    RunningAverage(output_transform=lambda out: out[1]).attach(
        engine, "avg_loss_gen")
    RunningAverage(output_transform=lambda out: out[0]).attach(
        engine, "avg_loss_dis")

    handler = tb_logger.OutputHandler(
        tag="train", metric_names=["avg_loss_gen", "avg_loss_dis"])
    tb.attach(engine,
              log_handler=handler,
              event_name=Events.ITERATION_COMPLETED)

    @engine.on(Events.ITERATION_COMPLETED(every=100))
    def log_training_loss(engine):
        print(f"Epoch[{engine.state.iteration}] Loss:", engine.state.output)

    engine.run(data=generate_batch(envs))