Example 1
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "images": Last(),
        "loss": Mean(),
        "lr": Last(),
    }

    # loop over batches ################################################################################################
    model.train()
    for images, meta, targets in tqdm(
        data_loader,
        desc="fold {}, epoch {}/{}, train".format(config.fold, epoch, config.train.epochs),
    ):
        images, meta, targets = (
            images.to(DEVICE),
            {k: meta[k].to(DEVICE) for k in meta},
            targets.to(DEVICE),
        )
        # images, targets = mix_up(images, targets, alpha=1.)

        logits = model(images, meta)
        loss = compute_loss(input=logits, target=targets, config=config)

        metrics["images"].update(images.data.cpu())
        metrics["loss"].update(loss.data.cpu())
        metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

    # compute metrics ##################################################################################################
    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}

        writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
        writer.add_image(
            "images",
            torchvision.utils.make_grid(
                metrics["images"], nrow=compute_nrow(metrics["images"]), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_scalar("loss", metrics["loss"], global_step=epoch)
        writer.add_scalar("lr", metrics["lr"], global_step=epoch)

        writer.flush()
        writer.close()
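
Every example in this collection relies on small streaming-metric helpers (Mean, Last, and in later examples Concat and FPS) that expose an update / compute_and_reset interface but are not shown. Below is a minimal sketch of the two most common ones; the class and method names come from the usage above, while the implementations are assumptions.

import torch


class Last:
    # Keeps only the most recent value passed to update().
    def __init__(self):
        self.value = None

    def update(self, value):
        self.value = value

    def compute_and_reset(self):
        value, self.value = self.value, None
        return value


class Mean:
    # Running mean of everything passed to update() since the last reset.
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value):
        value = torch.as_tensor(value, dtype=torch.float)
        self.sum += value.sum().item()
        self.count += value.numel()

    def compute_and_reset(self):
        mean = self.sum / max(self.count, 1)
        self.sum, self.count = 0.0, 0
        return mean
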
Example 2
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "p/norm": Mean(),
        "z/norm": Mean(),
        "p/std": Mean(),
        "z/std": Mean(),
    }

    model.train()
    for images1, images2 in tqdm(
            data_loader,
            desc="epoch {}/{}, train".format(epoch, config.epochs),
    ):
        images1, images2 = images1.to(DEVICE), images2.to(DEVICE)

        p1, z1 = model(images1)
        p2, z2 = model(images2)

        loss = (sum([
            compute_loss(p=p1, z=z2.detach()),
            compute_loss(p=p2, z=z1.detach()),
        ]) / 2)

        metrics["loss"].update(loss.detach())
        metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))

        for k, v in [("p", p1), ("p", p2), ("z", z1), ("z", z2)]:
            v = v.detach()
            metrics["{}/norm".format(k)].update(v.norm(dim=1))
            v = F.normalize(v, dim=1)
            metrics["{}/std".format(k)].update(v.std(dim=0))

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        images = torch.cat([images1, images2], 3)

        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_image(
            "images",
            torchvision.utils.make_grid(images[:16], nrow=1, normalize=True),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
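
Example 2 is a SimSiam-style setup: each view yields a predictor output p and a projection z, the loss is symmetrized, and the stop-gradient is applied through .detach() on z. The compute_loss(p=..., z=...) it calls is presumably the negative cosine similarity from that recipe; a per-sample sketch under that assumption:

import torch.nn.functional as F


def compute_loss(p, z):
    # Negative cosine similarity between the predictor output and the (detached)
    # projection, returned per sample so the caller can average or log it.
    return -(F.normalize(p, dim=1) * F.normalize(z, dim=1)).sum(dim=1)
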
Example 3
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "loss": Mean(),
        "lr": Last(),
    }

    model.train()
    for images, targets in tqdm(data_loader,
                                desc="epoch {}/{}, train".format(
                                    epoch, config.epochs)):
        images, targets = images.to(DEVICE), targets.to(DEVICE)

        logits = model(images)
        loss = F.cross_entropy(input=logits, target=targets, reduction="none")

        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "images",
            torchvision.utils.make_grid(denormalize(images),
                                        nrow=compute_nrow(images),
                                        normalize=True),
            global_step=epoch,
        )
        writer.add_histogram("params",
                             flatten_weights(model.parameters()),
                             global_step=epoch)

    writer.flush()
    writer.close()
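
denormalize and compute_nrow are small image utilities shared by most of these snippets. Plausible definitions follow, assuming ImageNet-style per-channel normalization and a grid kept roughly square; the mean/std constants are assumptions (example 4 suggests they also exist as module-level MEAN and STD).

import math

import torch

MEAN = torch.tensor([0.485, 0.456, 0.406])
STD = torch.tensor([0.229, 0.224, 0.225])


def denormalize(images, mean=MEAN, std=STD):
    # Undo per-channel normalization: x * std + mean.
    mean = mean.to(images.device).view(1, -1, 1, 1)
    std = std.to(images.device).view(1, -1, 1, 1)
    return images * std + mean


def compute_nrow(images):
    # Number of grid columns that makes the torchvision image grid roughly square.
    b, _, h, w = images.size()
    return max(1, round(math.sqrt(b * h / w)))
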
Example 4
def train_epoch(model, optimizer, scheduler, data_loader, box_coder,
                class_names, epoch, config):
    metrics = {
        "loss": Mean(),
        "loss/class": Mean(),
        "loss/loc": Mean(),
        "loss/cent": Mean(),
        "learning_rate": Last(),
    }

    model.train()
    optimizer.zero_grad()
    for i, batch in tqdm(enumerate(data_loader, 1),
                         desc="epoch {} train".format(epoch),
                         total=len(data_loader)):
        images, targets, dets_true = apply_recursively(lambda x: x.to(DEVICE),
                                                       batch)

        output = model(images)

        loss_dict = compute_loss(input=output, target=targets)
        loss = sum(loss_dict.values())

        metrics["loss"].update(loss.data.cpu())
        for k in loss_dict:
            metrics["loss/{}".format(k)].update(loss_dict[k].data.cpu())
        metrics["learning_rate"].update(np.squeeze(scheduler.get_lr()))

        (loss.mean() / config.train.acc_steps).backward()
        if i % config.train.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        writer = SummaryWriter(os.path.join(config.experiment_path, "train"))

        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)

        images = denormalize(images, mean=MEAN, std=STD)

        dets_true = [
            box_coder.decode(foreground_binary_coding(c, 80), r, s,
                             images.size()[2:]) for c, r, s in zip(*targets)
        ]
        dets_pred = [
            box_coder.decode(c.sigmoid(), r, s.sigmoid(),
                             images.size()[2:]) for c, r, s in zip(*output)
        ]

        true = [
            draw_boxes(i, d, class_names) for i, d in zip(images, dets_true)
        ]
        pred = [
            draw_boxes(i, d, class_names) for i, d in zip(images, dets_pred)
        ]

        writer.add_image("detections/true",
                         torchvision.utils.make_grid(true, nrow=4),
                         global_step=epoch)
        writer.add_image("detections/pred",
                         torchvision.utils.make_grid(pred, nrow=4),
                         global_step=epoch)

        writer.flush()
        writer.close()
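
Two details worth noting in example 4: the loss is divided by config.train.acc_steps and optimizer.step() runs only every acc_steps batches, which is plain gradient accumulation, and the whole nested detection batch is moved to the GPU with a single apply_recursively call. A minimal sketch of the latter, assuming it simply maps a function over nested lists, tuples and dicts of tensors:

def apply_recursively(fn, value):
    # Apply fn to every leaf of a nested structure of lists, tuples and dicts.
    if isinstance(value, (list, tuple)):
        return type(value)(apply_recursively(fn, v) for v in value)
    if isinstance(value, dict):
        return {k: apply_recursively(fn, v) for k, v in value.items()}
    return fn(value)
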
Example 5
def main():
    args = build_parser().parse_args()
    config = build_default_config()
    config.merge_from_file(args.config_path)
    config.experiment_path = args.experiment_path
    config.render = not args.no_render
    config.freeze()
    del args

    writer = SummaryWriter(config.experiment_path)

    seed_torch(config.seed)
    env = VecEnv([lambda: build_env(config) for _ in range(config.workers)])
    if config.render:
        env = wrappers.TensorboardBatchMonitor(env, writer, config.log_interval)
    env = wrappers.torch.Torch(env, device=DEVICE)
    env.seed(config.seed)

    policy_model = ModelDQN(config.model, env.observation_space, env.action_space).to(DEVICE)
    target_model = ModelDQN(config.model, env.observation_space, env.action_space).to(DEVICE)
    target_model.load_state_dict(policy_model.state_dict())
    optimizer = build_optimizer(config.opt, policy_model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.episodes)

    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "eps": FPS(),
        "ep/length": Mean(),
        "ep/reward": Mean(),
    }

    # ==================================================================================================================
    # training loop
    policy_model.train()
    target_model.eval()
    episode = 0
    s = env.reset()
    e_base = 0.95
    e_step = np.exp(np.log(0.05 / e_base) / config.episodes)

    bar = tqdm(total=config.episodes, desc="training")
    history = History()
    while episode < config.episodes:
        with torch.no_grad():
            for _ in range(config.horizon):
                av = policy_model(s)
                a = sample_action(av, e_base * e_step ** episode)
                s_prime, r, d, meta = env.step(a)
                history.append(
                    state=s.cpu(),
                    action=a.cpu(),
                    reward=r.cpu(),
                    done=d.cpu(),
                    state_prime=s_prime.cpu(),
                )
                # history.append(state=s, action=a, reward=r, done=d, state_prime=s_prime)
                s = s_prime

                (indices,) = torch.where(d)
                for i in indices:
                    metrics["eps"].update(1)
                    metrics["ep/length"].update(meta[i]["episode"]["l"])
                    metrics["ep/reward"].update(meta[i]["episode"]["r"])
                    episode += 1
                    scheduler.step()
                    bar.update(1)

                    if episode % 10 == 0:
                        target_model.load_state_dict(policy_model.state_dict())

                    if episode % config.log_interval == 0 and episode > 0:
                        for k in metrics:
                            writer.add_scalar(
                                k, metrics[k].compute_and_reset(), global_step=episode
                            )
                        writer.add_scalar("e", e_base * e_step ** episode, global_step=episode)
                        writer.add_histogram(
                            "rollout/action", rollout.actions, global_step=episode
                        )
                        writer.add_histogram(
                            "rollout/reward", rollout.rewards, global_step=episode
                        )
                        writer.add_histogram("rollout/return", returns, global_step=episode)
                        writer.add_histogram(
                            "rollout/action_value", action_values, global_step=episode
                        )

        rollout = history.full_rollout()
        action_values = policy_model(rollout.states)
        action_values = action_values * one_hot(rollout.actions, action_values.size(-1))
        action_values = action_values.sum(-1)
        with torch.no_grad():
            action_values_prime = target_model(rollout.states_prime)
            action_values_prime, _ = action_values_prime.detach().max(-1)
        returns = one_step_discounted_return(
            rollout.rewards, action_values_prime, rollout.dones, gamma=config.gamma
        )

        # critic
        errors = returns - action_values
        critic_loss = errors ** 2

        loss = (critic_loss * 0.5).mean(1)

        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))

        # training
        optimizer.zero_grad()
        loss.mean().backward()
        nn.utils.clip_grad_norm_(policy_model.parameters(), 0.5)
        optimizer.step()

    bar.close()
    env.close()
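
The DQN target in example 5 is the standard one-step bootstrap: reward plus the discounted maximum target-network value of the next state, with the bootstrap term cut on terminal transitions. A sketch of one_step_discounted_return consistent with how it is called above; the exact signature is an assumption.

def one_step_discounted_return(rewards, values_prime, dones, gamma):
    # r + gamma * max_a' Q_target(s', a'), with the bootstrap zeroed where done.
    return rewards + gamma * values_prime * (1.0 - dones.float())
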
Example 6
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "loss/x": Mean(),
        "loss/u": Mean(),
        "weight/u": Last(),
        "lr": Last(),
    }

    model.train()
    for (images_x, targets_x), (images_u_0, images_u_1) in tqdm(
            data_loader,
            desc="epoch {}/{}, train".format(epoch, config.epochs)):
        # prepare data #################################################################################################
        images_x, targets_x, images_u_0, images_u_1 = (
            images_x.to(DEVICE),
            targets_x.to(DEVICE),
            images_u_0.to(DEVICE),
            images_u_1.to(DEVICE),
        )
        targets_x = one_hot(targets_x, NUM_CLASSES)

        # mix-match ####################################################################################################
        with torch.no_grad():
            (images_x, targets_x), (images_u, targets_u) = mix_match(
                x=(images_x, targets_x),
                u=(images_u_0, images_u_1),
                model=model,
                config=config,
            )

        probs_x, probs_u = model(torch.cat([images_x, images_u])).split(
            [images_x.size(0), images_u.size(0)])

        # x ############################################################################################################
        loss_x = compute_loss_x(input=probs_x, target=targets_x)
        metrics["loss/x"].update(loss_x.data.cpu().numpy())

        # u ############################################################################################################
        loss_u = compute_loss_u(input=probs_u, target=targets_u)
        metrics["loss/u"].update(loss_u.data.cpu().numpy())

        # opt step #####################################################################################################
        metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))
        weight_u = config.train.mix_match.weight_u * min(
            (epoch - 1) / config.epochs_warmup, 1.0)
        metrics["weight/u"].update(weight_u)

        optimizer.zero_grad()
        (loss_x.mean() + weight_u * loss_u.mean()).backward()
        optimizer.step()
        scheduler.step()

    if epoch % config.log_interval != 0:
        return

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "images_x",
            torchvision.utils.make_grid(images_x,
                                        nrow=compute_nrow(images_x),
                                        normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "images_u",
            torchvision.utils.make_grid(images_u,
                                        nrow=compute_nrow(images_u),
                                        normalize=True),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
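
Example 6 follows MixMatch: labeled targets are converted to one-hot vectors so mix_match can blend them with guessed labels for the unlabeled views, and the unlabeled-loss weight is ramped up linearly over config.epochs_warmup epochs. The one_hot helper is presumably a thin float wrapper around F.one_hot; a sketch under that assumption:

import torch.nn.functional as F


def one_hot(targets, num_classes):
    # Integer class indices -> float one-hot rows, suitable for soft-label mixing.
    return F.one_hot(targets, num_classes).float()
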
Example 7
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "loss": Mean(),
        "lr": Last(),
    }

    model.train()
    for (text, text_mask), (audio, audio_mask) in tqdm(
            data_loader,
            desc="epoch {}/{}, train".format(epoch, config.train.epochs)):
        text, audio, text_mask, audio_mask = [
            x.to(DEVICE) for x in [text, audio, text_mask, audio_mask]
        ]

        output, pre_output, target, target_mask, weight = model(
            text, text_mask, audio, audio_mask)

        loss = masked_mse(output, target, target_mask) + masked_mse(
            pre_output, target, target_mask)

        metrics["loss"].update(loss.data.cpu())
        metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))

        optimizer.zero_grad()
        loss.mean().backward()
        if config.train.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           config.train.clip_grad_norm)
        optimizer.step()
        scheduler.step()

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        gl_true = griffin_lim(target, model.spectra)
        gl_pred = griffin_lim(output, model.spectra)
        output, pre_output, target, weight = [
            x.unsqueeze(1) for x in [output, pre_output, target, weight]
        ]
        nrow = compute_nrow(target)

        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "target",
            torchvision.utils.make_grid(target, nrow=nrow, normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "output",
            torchvision.utils.make_grid(output, nrow=nrow, normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "pre_output",
            torchvision.utils.make_grid(pre_output, nrow=nrow, normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "weight",
            torchvision.utils.make_grid(weight, nrow=nrow, normalize=True),
            global_step=epoch,
        )
        for i in tqdm(range(min(text.size(0), 4)), desc="writing audio"):
            writer.add_audio("audio/{}".format(i),
                             audio[i],
                             sample_rate=config.sample_rate,
                             global_step=epoch)
            writer.add_audio(
                "griffin-lim-true/{}".format(i),
                gl_true[i],
                sample_rate=config.sample_rate,
                global_step=epoch,
            )
            writer.add_audio(
                "griffin-lim-pred/{}".format(i),
                gl_pred[i],
                sample_rate=config.sample_rate,
                global_step=epoch,
            )

    writer.flush()
    writer.close()
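
masked_mse in example 7 is evidently a mean squared error that ignores padded spectrogram frames via target_mask. A plausible definition; how the mask broadcasts against the spectrogram is an assumption.

def masked_mse(input, target, mask):
    # Squared error only where mask is nonzero, averaged over the unmasked positions.
    # Assumes mask broadcasts against input (unsqueeze a feature dim first if it does not).
    mask = mask.float()
    return ((input - target) ** 2 * mask).sum() / mask.sum().clamp(min=1)
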
Example 8
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    metrics = {
        "loss": Mean(),
        "lr": Last(),
    }

    model.train()
    for images, targets in tqdm(
        data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs)
    ):
        images, targets = images.to(DEVICE), targets.to(DEVICE)
        targets = image_one_hot(targets, NUM_CLASSES)

        logits = model(images)
        loss = compute_loss(input=logits, target=targets)

        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

    with torch.no_grad():
        mask_true = F.interpolate(draw_masks(targets.argmax(1, keepdim=True)), scale_factor=1)
        mask_pred = F.interpolate(draw_masks(logits.argmax(1, keepdim=True)), scale_factor=1)

        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "images",
            torchvision.utils.make_grid(
                denormalize(images), nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "mask_true",
            torchvision.utils.make_grid(mask_true, nrow=compute_nrow(mask_true), normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "mask_pred",
            torchvision.utils.make_grid(mask_pred, nrow=compute_nrow(mask_pred), normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "images_true",
            torchvision.utils.make_grid(
                denormalize(images) + mask_true, nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "images_pred",
            torchvision.utils.make_grid(
                denormalize(images) + mask_pred, nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
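
Example 8 is semantic segmentation: integer masks are expanded to per-class channel maps with image_one_hot before the loss, and predictions are turned back into color masks with draw_masks for logging. A sketch of image_one_hot, assuming (B, H, W) integer targets:

import torch.nn.functional as F


def image_one_hot(targets, num_classes):
    # (B, H, W) integer masks -> (B, C, H, W) float one-hot channel maps.
    return F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
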
Example 9
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    metrics = {
        "x_loss": Mean(),
        "u_loss": Mean(),
        "u_loss_mask": Mean(),
        "lr": Last(),
    }

    model.train()
    for (x_w_images, x_targets), (u_w_images, u_s_images) in tqdm(
        data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs)
    ):
        x_w_images, x_targets, u_w_images, u_s_images = (
            x_w_images.to(DEVICE),
            x_targets.to(DEVICE),
            u_w_images.to(DEVICE),
            u_s_images.to(DEVICE),
        )

        x_w_logits, u_w_logits, u_s_logits = model(
            torch.cat([x_w_images, u_w_images, u_s_images], 0)
        ).split([x_w_images.size(0), u_w_images.size(0), u_s_images.size(0)])

        # x ############################################################################################################
        x_loss = F.cross_entropy(input=x_w_logits, target=x_targets, reduction="none")
        metrics["x_loss"].update(x_loss.data.cpu().numpy())

        # u ############################################################################################################
        u_loss_mask, u_targets = F.softmax(u_w_logits.detach(), 1).max(1)
        u_loss_mask = (u_loss_mask >= config.train.tau).float()

        u_loss = u_loss_mask * F.cross_entropy(
            input=u_s_logits, target=u_targets, reduction="none"
        )
        metrics["u_loss"].update(u_loss.data.cpu().numpy())
        metrics["u_loss_mask"].update(u_loss_mask.data.cpu().numpy())

        # opt step #####################################################################################################
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))

        optimizer.zero_grad()
        (x_loss.mean() + config.train.u_weight * u_loss.mean()).backward()
        optimizer.step()
        scheduler.step()

    if epoch % config.log_interval != 0:
        return

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "x_w_images",
            torchvision.utils.make_grid(
                denormalize(x_w_images), nrow=compute_nrow(x_w_images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "u_w_images",
            torchvision.utils.make_grid(
                denormalize(u_w_images), nrow=compute_nrow(u_w_images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "u_s_images",
            torchvision.utils.make_grid(
                denormalize(u_s_images), nrow=compute_nrow(u_s_images), normalize=True
            ),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
Example 10
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)
    del config_path, kwargs

    writer = SummaryWriter(config.experiment_path)

    seed_torch(config.seed)
    env = VecEnv([lambda: build_env(config) for _ in range(config.workers)])
    if config.render:
        env = wrappers.TensorboardBatchMonitor(env, writer,
                                               config.log_interval)
    env = wrappers.Torch(env, dtype=torch.float, device=DEVICE)
    env.seed(config.seed)

    model = Model(config.model, env.observation_space, env.action_space)
    model = model.to(DEVICE)
    if config.restore_path is not None:
        model.load_state_dict(torch.load(config.restore_path))
    optimizer = build_optimizer(config.opt, model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, config.episodes)

    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "eps": FPS(),
        "ep/length": Mean(),
        "ep/return": Mean(),
        "rollout/entropy": Mean(),
    }

    # ==================================================================================================================
    # training loop
    model.train()
    episode = 0
    s = env.reset()

    bar = tqdm(total=config.episodes, desc="training")
    while episode < config.episodes:
        history = History()

        with torch.no_grad():
            for _ in range(config.horizon):
                a, _ = model(s)
                a = a.sample()
                s_prime, r, d, info = env.step(a)
                history.append(state=s,
                               action=a,
                               reward=r,
                               done=d,
                               state_prime=s_prime)
                s = s_prime

                (indices, ) = torch.where(d)
                for i in indices:
                    metrics["eps"].update(1)
                    metrics["ep/length"].update(info[i]["episode"]["l"])
                    metrics["ep/return"].update(info[i]["episode"]["r"])
                    episode += 1
                    scheduler.step()
                    bar.update(1)

                    if episode % config.log_interval == 0 and episode > 0:
                        for k in metrics:
                            writer.add_scalar(k,
                                              metrics[k].compute_and_reset(),
                                              global_step=episode)
                        writer.add_histogram("rollout/action",
                                             rollout.actions,
                                             global_step=episode)
                        writer.add_histogram("rollout/reward",
                                             rollout.rewards,
                                             global_step=episode)
                        writer.add_histogram("rollout/return",
                                             returns,
                                             global_step=episode)
                        writer.add_histogram("rollout/value",
                                             values,
                                             global_step=episode)
                        writer.add_histogram("rollout/advantage",
                                             advantages,
                                             global_step=episode)

                        torch.save(
                            model.state_dict(),
                            os.path.join(config.experiment_path,
                                         "model_{}.pth".format(episode)),
                        )

        rollout = history.full_rollout()
        dist, values = model(rollout.states)
        with torch.no_grad():
            _, value_prime = model(rollout.states_prime[:, -1])
            returns = n_step_bootstrapped_return(rollout.rewards,
                                                 value_prime,
                                                 rollout.dones,
                                                 discount=config.gamma)

        # critic
        errors = returns - values
        critic_loss = errors**2

        # actor
        advantages = errors.detach()
        log_prob = dist.log_prob(rollout.actions)
        entropy = dist.entropy()

        if isinstance(env.action_space, gym.spaces.Box):
            log_prob = log_prob.sum(-1)
            entropy = entropy.sum(-1)
        assert log_prob.dim() == entropy.dim() == 2

        actor_loss = -log_prob * advantages - config.entropy_weight * entropy

        # loss
        loss = (actor_loss + 0.5 * critic_loss).mean(1)

        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))
        metrics["rollout/entropy"].update(dist.entropy().data.cpu().numpy())

        # training
        optimizer.zero_grad()
        loss.mean().backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

    bar.close()
    env.close()
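
Example 10 is an A2C-style actor-critic: the critic regresses onto n-step bootstrapped returns, and the actor maximizes advantage-weighted log-probabilities with an entropy bonus. A sketch of n_step_bootstrapped_return consistent with the call above; the [batch, time] layout of rewards and dones is an assumption.

import torch


def n_step_bootstrapped_return(rewards, value_prime, dones, discount):
    # Walk the horizon backwards, bootstrapping from the value of the state after
    # the last step and cutting the return at episode boundaries.
    ret = value_prime
    returns = torch.zeros_like(rewards)
    for t in reversed(range(rewards.size(1))):
        ret = rewards[:, t] + discount * ret * (1.0 - dones[:, t].float())
        returns[:, t] = ret
    return returns
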
Example 11
def train_epoch(model, data_loader, opt_teacher, opt_student, sched_teacher,
                sched_student, epoch, config):
    metrics = {
        "teacher/loss": Mean(),
        "teacher/grad_norm": Mean(),
        "teacher/lr": Last(),
        "student/loss": Mean(),
        "student/grad_norm": Mean(),
        "student/lr": Last(),
    }

    model.train()
    for (x_image, x_target), (u_image, ) in tqdm(
            data_loader,
            desc="epoch {}/{}, train".format(epoch, config.epochs)):
        x_image, x_target, u_image = (
            x_image.to(DEVICE),
            x_target.to(DEVICE),
            u_image.to(DEVICE),
        )

        with higher.innerloop_ctx(model.student,
                                  opt_student) as (h_model_student,
                                                   h_opt_student):
            # student ##################################################################################################

            loss_student = cross_entropy(input=h_model_student(u_image),
                                         target=model.teacher(u_image)).mean()
            metrics["student/loss"].update(loss_student.data.cpu().numpy())
            metrics["student/lr"].update(
                np.squeeze(sched_student.get_last_lr()))

            def grad_callback(grads):
                metrics["student/grad_norm"].update(
                    grad_norm(grads).data.cpu().numpy())

                return grads

            h_opt_student.step(loss_student.mean(),
                               grad_callback=grad_callback)
            sched_student.step()

            # teacher ##################################################################################################

            loss_teacher = (
                cross_entropy(input=model.teacher(x_image),
                              target=one_hot(x_target, NUM_CLASSES)).mean() +
                cross_entropy(input=h_model_student(x_image),
                              target=one_hot(x_target, NUM_CLASSES)).mean())

            metrics["teacher/loss"].update(loss_teacher.data.cpu().numpy())
            metrics["teacher/lr"].update(
                np.squeeze(sched_teacher.get_last_lr()))

            opt_teacher.zero_grad()
            loss_teacher.mean().backward()
            opt_teacher.step()
            metrics["teacher/grad_norm"].update(
                grad_norm(
                    p.grad
                    for p in model.teacher.parameters()).data.cpu().numpy())
            sched_teacher.step()

            # copy student weights #####################################################################################

            with torch.no_grad():
                for p, p_prime in zip(model.student.parameters(),
                                      h_model_student.parameters()):
                    p.copy_(p_prime)

    if epoch % config.log_interval != 0:
        return

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image(
            "x_image",
            torchvision.utils.make_grid(denormalize(x_image),
                                        nrow=compute_nrow(x_image),
                                        normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "u_image",
            torchvision.utils.make_grid(denormalize(u_image),
                                        nrow=compute_nrow(u_image),
                                        normalize=True),
            global_step=epoch,
        )

    writer.flush()
    writer.close()
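
grad_norm in example 11 reduces a collection of gradient tensors to a single global L2 norm, both inside the higher grad_callback and over the teacher's .grad fields. A minimal sketch, assuming that behaviour:

import torch


def grad_norm(grads):
    # Global L2 norm over all gradient tensors, skipping missing gradients.
    norms = [g.norm() for g in grads if g is not None]
    if not norms:
        return torch.tensor(0.0)
    return torch.stack(norms).norm()
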
Example 12
def train_epoch(model, data_loader, fold_probs, optimizer, scheduler, epoch,
                config):
    writer = SummaryWriter(
        os.path.join(config.experiment_path, 'F{}'.format(config.fold),
                     'train'))
    metrics = {
        'loss': Mean(),
        'loss_hist': Concat(),
        'entropy': Mean(),
        'lr': Last(),
    }

    model.train()
    for images, targets, indices in tqdm(data_loader,
                                         desc='[F{}][epoch {}] train'.format(
                                             config.fold, epoch)):
        images, targets, indices = (
            images.to(DEVICE),
            targets.to(DEVICE),
            indices.to(DEVICE),
        )

        if epoch >= config.train.self_distillation.start_epoch:
            targets = weighted_sum(
                targets, fold_probs[indices],
                config.train.self_distillation.target_weight)
        if config.train.cutmix is not None:
            if np.random.uniform() > (epoch - 1) / (config.epochs - 1):
                images, targets = utils.cutmix(images, targets,
                                               config.train.cutmix)

        logits, etc = model(images)

        loss = compute_loss(input=logits, target=targets, config=config.train)

        metrics['loss'].update(loss.data.cpu().numpy())
        metrics['loss_hist'].update(loss.data.cpu().numpy())
        metrics['entropy'].update(compute_entropy(logits).data.cpu().numpy())
        metrics['lr'].update(np.squeeze(scheduler.get_lr()))

        loss.mean().backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        # FIXME:
        if epoch >= config.train.self_distillation.start_epoch:
            probs = torch.cat(
                [i.softmax(-1) for i in split_target(logits.detach())], -1)
            fold_probs[indices] = weighted_sum(
                fold_probs[indices], probs,
                config.train.self_distillation.pred_ewa)

    for k in metrics:
        if k.endswith('_hist'):
            writer.add_histogram(k,
                                 metrics[k].compute_and_reset(),
                                 global_step=epoch)
        else:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
    writer.add_image('images',
                     torchvision.utils.make_grid(images,
                                                 nrow=compute_nrow(images),
                                                 normalize=True),
                     global_step=epoch)
    if 'stn' in etc:
        writer.add_image('stn',
                         torchvision.utils.make_grid(etc['stn'],
                                                     nrow=compute_nrow(etc['stn']),
                                                     normalize=True),
                         global_step=epoch)

    writer.flush()
    writer.close()
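
weighted_sum in example 12 is used both to soften the training targets towards the model's running predictions (self-distillation) and to update those running predictions as an exponential moving average, so it is presumably a plain convex combination. A sketch; the argument order and weight semantics are assumptions inferred from the config field names.

def weighted_sum(a, b, weight):
    # Convex combination: weight * a + (1 - weight) * b.
    return weight * a + (1 - weight) * b
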
Example 13
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)
    del config_path, kwargs

    writer = SummaryWriter(config.experiment_path)

    seed_torch(config.seed)
    env = wrappers.Batch(build_env(config))
    if config.render:
        env = wrappers.TensorboardBatchMonitor(env, writer, config.log_interval)
    env = wrappers.torch.Torch(env, device=DEVICE)
    env.seed(config.seed)

    model = Model(config.model, env.observation_space, env.action_space)
    model = model.to(DEVICE)
    if config.restore_path is not None:
        model.load_state_dict(torch.load(config.restore_path))
    optimizer = build_optimizer(config.opt, model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.episodes)

    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "eps": FPS(),
        "ep/length": Mean(),
        "ep/return": Mean(),
        "rollout/reward": Mean(),
        "rollout/advantage": Mean(),
        "rollout/entropy": Mean(),
    }

    # training loop ====================================================================================================
    for episode in tqdm(range(config.episodes), desc="training"):
        hist = History()
        s = env.reset()
        h = model.zero_state(1)
        d = torch.ones(1, dtype=torch.bool)

        model.eval()
        with torch.no_grad():
            while True:
                trans = hist.append_transition()

                trans.record(state=s, hidden=h, done=d)
                a, _, h = model(s, h, d)
                a = a.sample()
                s, r, d, info = env.step(a)
                trans.record(action=a, reward=r)

                if d:
                    break

        # optimization =================================================================================================
        model.train()

        # build rollout
        rollout = hist.full_rollout()

        # loss
        loss = compute_loss(env, model, rollout, metrics, config)

        # metrics
        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))

        # training
        optimizer.zero_grad()
        loss.mean().backward()
        if config.grad_clip_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
        optimizer.step()
        scheduler.step()

        metrics["eps"].update(1)
        metrics["ep/length"].update(info[0]["episode"]["l"])
        metrics["ep/return"].update(info[0]["episode"]["r"])

        if episode % config.log_interval == 0 and episode > 0:
            for k in metrics:
                writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=episode)
            torch.save(
                model.state_dict(),
                os.path.join(config.experiment_path, "model_{}.pth".format(episode)),
            )
Example 14
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)

    transform, update_transform = build_transform()

    if config.dataset == "mnist":
        dataset = torchvision.datasets.MNIST(config.dataset_path,
                                             transform=transform,
                                             download=True)
    elif config.dataset == "celeba":
        dataset = torchvision.datasets.ImageFolder(config.dataset_path,
                                                   transform=transform)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        drop_last=True,
    )

    model = nn.ModuleDict({
        "discriminator": Discriminator(config.image_size),
        "generator": Generator(config.image_size, config.latent_size),
    })
    model.to(DEVICE)
    if config.restore_path is not None:
        model.load_state_dict(torch.load(config.restore_path))

    discriminator_opt = torch.optim.Adam(model.discriminator.parameters(),
                                         lr=config.opt.lr,
                                         betas=config.opt.beta,
                                         eps=1e-8)
    generator_opt = torch.optim.Adam(model.generator.parameters(),
                                     lr=config.opt.lr,
                                     betas=config.opt.beta,
                                     eps=1e-8)

    noise_dist = torch.distributions.Normal(0, 1)

    writer = SummaryWriter(config.experiment_path)
    metrics = {
        "loss/discriminator": Mean(),
        "loss/generator": Mean(),
        "level": Last(),
        "alpha": Last(),
    }

    for epoch in range(1, config.epochs + 1):
        model.train()

        level, _ = compute_level(epoch - 1, config.epochs, 0, len(data_loader),
                                 config.image_size, config.grow_min_level)
        update_transform(int(4 * 2**level))

        for i, (real, _) in enumerate(
                tqdm(data_loader, desc="epoch {} training".format(epoch))):
            _, a = compute_level(
                epoch - 1,
                config.epochs,
                i,
                len(data_loader),
                config.image_size,
                config.grow_min_level,
            )

            real = real.to(DEVICE)

            # discriminator ############################################################################################
            discriminator_opt.zero_grad()

            # real
            scores = model.discriminator(real, level=level, a=a)
            loss = F.softplus(-scores)
            loss.mean().backward()
            loss_real = loss

            # fake
            noise = noise_dist.sample(
                (config.batch_size, config.latent_size)).to(DEVICE)
            fake = model.generator(noise, level=level, a=a)
            assert real.size() == fake.size()
            scores = model.discriminator(fake, level=level, a=a)
            loss = F.softplus(scores)
            loss.mean().backward()
            loss_fake = loss

            discriminator_opt.step()
            metrics["loss/discriminator"].update(
                (loss_real + loss_fake).data.cpu().numpy())

            # generator ################################################################################################
            generator_opt.zero_grad()

            # fake
            noise = noise_dist.sample(
                (config.batch_size, config.latent_size)).to(DEVICE)
            fake = model.generator(noise, level=level, a=a)
            assert real.size() == fake.size()
            scores = model.discriminator(fake, level=level, a=a)
            loss = F.softplus(-scores)
            loss.mean().backward()

            generator_opt.step()
            metrics["loss/generator"].update(loss.data.cpu().numpy())

            metrics["level"].update(level)
            metrics["alpha"].update(a)

        for k in metrics:
            writer.add_scalar(k,
                              metrics[k].compute_and_reset(),
                              global_step=epoch)
        writer.add_image("real",
                         utils.make_grid((real + 1) / 2),
                         global_step=epoch)
        writer.add_image("fake",
                         utils.make_grid((fake + 1) / 2),
                         global_step=epoch)

        torch.save(
            model.state_dict(),
            os.path.join(config.experiment_path, "model_{}.pth".format(epoch)))
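
The GAN in example 14 uses the non-saturating logistic losses: softplus(-D(x)) on real images and softplus(D(G(z))) on fakes for the discriminator, and softplus(-D(G(z))) for the generator. These are exactly binary cross-entropy with logits against labels 1, 0 and 1 respectively, which the following small check confirms:

import torch
import torch.nn.functional as F

scores = torch.randn(8)
ones, zeros = torch.ones(8), torch.zeros(8)

# softplus(-s) is BCE-with-logits against label 1; softplus(s) against label 0.
assert torch.allclose(F.softplus(-scores),
                      F.binary_cross_entropy_with_logits(scores, ones, reduction="none"))
assert torch.allclose(F.softplus(scores),
                      F.binary_cross_entropy_with_logits(scores, zeros, reduction="none"))
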