Example #1
    def train(self):
        total_step = len(self.data_loader)
        optimizer = Adam(self.transfer_net.parameters(), lr=self.lr)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       self.decay_epoch, 0.5)
        content_criterion = nn.MSELoss()
        style_criterion = nn.MSELoss()
        self.transfer_net.train()
        self.vgg.eval()

        for epoch in range(self.epoch, self.num_epoch):
            sample_path = os.path.join(self.sample_dir,
                                       self.style_image_name, f"{epoch}")
            os.makedirs(sample_path, exist_ok=True)
            for step, image in enumerate(self.data_loader):

                optimizer.zero_grad()
                image = image.to(self.device)
                transformed_image = self.transfer_net(image)

                image_feature = self.vgg(image)
                transformed_image_feature = self.vgg(transformed_image)

                content_loss = self.content_weight * content_criterion(
                    image_feature.relu2_2, transformed_image_feature.relu2_2)

                style_loss = 0
                for ft_y, gm_s in zip(transformed_image_feature,
                                      self.gram_style):
                    gm_y = gram_matrix(ft_y)
                    style_loss += style_criterion(gm_y,
                                                  gm_s[:self.batch_size, :, :])
                style_loss *= self.style_weight

                total_loss = content_loss + style_loss

                total_loss.backward(retain_graph=True)
                optimizer.step()

                if step % 10 == 0:
                    print(
                        f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                        f"[Style loss: {style_loss.item():.4}] [Content loss loss: {content_loss.item():.4}]"
                    )
                    if step % 100 == 0:
                        image = torch.cat((image, transformed_image), dim=2)
                        save_image(image,
                                   os.path.join(sample_path, f"{step}.png"),
                                   normalize=False)

            torch.save(
                self.transfer_net.state_dict(),
                os.path.join(self.checkpoint_dir, self.style_image_name,
                             f"TransferNet_{epoch}.pth"))
            lr_scheduler.step()
Example #2
def bc_step(config: ParamDict, policy: Policy, demo):
    lr_init, lr_factor, l2_reg, bc_method, batch_sz, i_iter = \
        config.require("lr", "lr factor", "l2 reg", "bc method", "batch size", "current training iter")
    states, actions = demo

    # ---- annealing on learning rate ---- #
    lr = max(lr_init + lr_factor * i_iter, 1.e-8)

    optimizer = Adam(policy.policy_net.parameters(),
                     weight_decay=l2_reg,
                     lr=lr)

    # ---- define BC from demonstrations ---- #
    total_len = states.size(0)
    idx = torch.randperm(total_len, device=policy.device)
    err = 0.
    for i_b in range((total_len + batch_sz - 1) // batch_sz):  # ceil division avoids a trailing empty batch
        idx_b = idx[i_b * batch_sz:(i_b + 1) * batch_sz]
        s_b = states[idx_b]
        a_b = actions[idx_b]

        optimizer.zero_grad()
        a_mean_pred, a_logvar_pred = policy.policy_net(s_b)
        bc_loss = mse_loss(a_mean_pred + 0. * a_logvar_pred, a_b)  # 0*logvar keeps the variance head in the graph
        err += bc_loss.item() * s_b.size(0) / total_len
        bc_loss.backward()
        optimizer.step()

    return err
Example #3
def train():
    tb_writer = SummaryWriter('tb_output')
    device = 'cuda:0' if CONF['GPU'] else 'cpu'
    model: nn.Module = CaseModel()
    # tb_writer.add_graph(model)
    model.train()
    
    train_dataset = CasRelDataset(path_or_json=CONF['TRAIN_DATA_PATH'])
    eval_dataset = CasRelDataset(path_or_json=CONF['EVAL_DATA_PATH'])
    dataloader = DataLoader(train_dataset,
                            batch_size=CONF['batch_size'],
                            shuffle=True,
                            collate_fn=collate_casrel)
    loss_func = BCEWithLogitsLoss()
    best_loss = 1e3
    optim = Adam(model.parameters(), lr=1e-5)
    global_steps = 0

    for epoch_num in range(Epochs):
        epoch_loss = 0.0
        model = model.to(device=device)
        for (batch_tokens,
             batch_mask,
             batch_sub_head,
             batch_sub_tail,
             batch_sub_head_arr,
             batch_sub_tail_arr,
             batch_obj_head_arr,
             batch_obj_tail_arr) in tqdm(dataloader, f'Epoch {epoch_num:3.0f}/{Epochs}', len(dataloader)):
            (batch_tokens, batch_mask, batch_sub_head, batch_sub_tail,
             batch_sub_head_arr, batch_sub_tail_arr,
             batch_obj_head_arr, batch_obj_tail_arr) = [
                 t.to(device)
                 for t in (batch_tokens, batch_mask, batch_sub_head,
                           batch_sub_tail, batch_sub_head_arr,
                           batch_sub_tail_arr, batch_obj_head_arr,
                           batch_obj_tail_arr)
             ]
            sub_head_pred, sub_tail_pred, obj_head_pred, obj_tail_pred = model(batch_tokens,
                                                                               batch_mask,
                                                                               batch_sub_head,
                                                                               batch_sub_tail)
            sub_head_loss = loss_func(sub_head_pred.squeeze(), batch_sub_head_arr)
            sub_tail_loss = loss_func(sub_tail_pred.squeeze(), batch_sub_tail_arr)
            obj_head_loss = loss_func(obj_head_pred, batch_obj_head_arr)
            obj_tail_loss = loss_func(obj_tail_pred, batch_obj_tail_arr)
            loss = sub_head_loss + sub_tail_loss + obj_head_loss + obj_tail_loss
            epoch_loss += loss.item()  # accumulate a float, not the graph-bearing tensor
            logger.info(f'batch loss: {loss.item():.4f}')
            global_steps += 1
            tb_writer.add_scalar('train_loss', loss, global_steps)
            optim.zero_grad()
            loss.backward()
            optim.step()
        # end one epoch

        p, r, f = metric(model.to('cpu'), eval_dataset)
        logger.info(f'epoch:{epoch_num + 1:3.0f}, precision: {p:5.4f}, recall: {r:5.4f}, f1-score: {f:5.4f}')
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            save_model = CONF['SAVE_MODEL']
            if not os.path.exists(os.path.dirname(save_model)):
                os.makedirs(os.path.dirname(save_model))
            torch.save(model.state_dict(), save_model)
Example #4
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # net
        self.linear0 = nn.Linear(2 * IN_DIM, IN_DIM)
        self.relu0 = nn.LeakyReLU(inplace=True)

        self.linear1 = nn.Linear(IN_DIM, IN_DIM // 2)
        self.relu1 = nn.LeakyReLU(inplace=True)

        self.linear2 = nn.Linear(IN_DIM // 2, 1)
        self.act_fn = nn.Tanh()

        # opt
        self.opt = Adam(self.parameters(), lr=1e-4)

    def forward(self, fea_q: Variable, fea_d: Variable):
        bs_anchor = fea_q.size()[0]
        bs_candi = fea_d.size()[0]
        assert bs_anchor == bs_candi or bs_anchor == 1

        fea_q = fea_q.expand(bs_candi, fea_q.size()[1])

        x = torch.cat([fea_q, fea_d], 1)

        x = self.linear0(x)
        x = self.relu0(x)

        x = self.linear1(x)
        x = self.relu1(x)

        x = self.linear2(x)

        x = x.view(-1)

        return x

    def bp(self, real_logit: Variable, fake_logit: Variable):
        size = real_logit.size()
        # real_label = Variable(torch.normal(torch.ones(size), torch.zeros(size) + 0.02)).cuda()
        # fake_label = Variable(torch.normal(torch.zeros(size), torch.zeros(size) + 0.02)).cuda()

        real_label = Variable(torch.ones(size)).cuda()
        fake_label = Variable(torch.zeros(size)).cuda()

        margins = F.threshold(
            0.8 - (torch.sigmoid(real_logit) - torch.sigmoid(fake_logit)), 0,
            0)
        loss = torch.mean(F.binary_cross_entropy_with_logits(real_logit, real_label, reduction='sum') +
                          F.binary_cross_entropy_with_logits(fake_logit, fake_label, reduction='sum')) + \
               torch.mean(margins)
        # loss = -(torch.mean(torch.log(score_real + 1e-6)) - torch.mean(torch.log(.5 + score_fake / 2 + 1e-6)))

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss
Example #5
def ppo_step(config: ParamDict, batch: SampleBatch, policy: PolicyWithValue):
    lr, l2_reg, clip_epsilon, policy_iter, i_iter, max_iter, mini_batch_sz = \
        config.require("lr", "l2 reg", "clip eps", "optimize policy epochs",
                       "current training iter", "max iter", "optimize batch size")
    lam_entropy = 0.
    states, actions, advantages, returns = get_tensor(batch, policy.device)

    lr_mult = max(1.0 - i_iter / max_iter, 0.)
    clip_epsilon = clip_epsilon * lr_mult

    optimizer_policy = Adam(policy.policy_net.parameters(),
                            lr=lr * lr_mult,
                            weight_decay=l2_reg)
    optimizer_value = Adam(policy.value_net.parameters(),
                           lr=lr * lr_mult,
                           weight_decay=l2_reg)

    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states,
                                                         actions).detach()

    for _ in range(policy_iter):
        inds = torch.randperm(states.size(0))
        """perform mini-batch PPO update"""
        for i_b in range(inds.size(0) // mini_batch_sz):
            slc = slice(i_b * mini_batch_sz, (i_b + 1) * mini_batch_sz)

            states_i = states[slc]
            actions_i = actions[slc]
            returns_i = returns[slc]
            advantages_i = advantages[slc]
            log_probs_i = fixed_log_probs[slc]
            """update critic"""
            for _ in range(1):
                value_loss = F.mse_loss(policy.value_net(states_i), returns_i)
                optimizer_value.zero_grad()
                value_loss.backward()
                torch.nn.utils.clip_grad_norm_(policy.value_net.parameters(),
                                               0.5)
                optimizer_value.step()
            """update policy"""
            log_probs, entropy = policy.policy_net.get_log_prob_entropy(
                states_i, actions_i)
            ratio = (log_probs - log_probs_i).clamp_max(15.).exp()
            surr1 = ratio * advantages_i
            surr2 = torch.clamp(ratio, 1.0 - clip_epsilon,
                                1.0 + clip_epsilon) * advantages_i
            policy_surr = -torch.min(
                surr1, surr2).mean() - entropy.mean() * lam_entropy
            optimizer_policy.zero_grad()
            policy_surr.backward()
            torch.nn.utils.clip_grad_norm_(policy.policy_net.parameters(), 0.5)
            optimizer_policy.step()
Example #6
class Selector(nn.Module):
    def __init__(self):
        super(Selector, self).__init__()

        # net
        self.linear_candi = nn.Linear(IN_DIM, IN_DIM // 2)
        self.relu_candi = nn.LeakyReLU(inplace=True)

        self.linear_anchor = nn.Linear(IN_DIM, IN_DIM // 2)
        self.relu_anchor = nn.LeakyReLU(inplace=True)

        self.linear1 = nn.Linear(IN_DIM, IN_DIM // 2)
        self.relu1 = nn.LeakyReLU(inplace=True)

        self.linear2 = nn.Linear(IN_DIM // 2, 1)
        self.act_fn = nn.Tanh()

        # opt
        self.opt = Adam(self.parameters(), lr=1e-4)

    def forward(self, fea_q, fea_d):
        bs_anchor = fea_q.size()[0]
        bs_candi = fea_d.size()[0]
        assert bs_anchor == bs_candi or bs_anchor == 1

        x_anchor = self.linear_anchor(fea_q)
        x_anchor = self.relu_anchor(x_anchor)
        x_anchor = x_anchor.expand(bs_candi, x_anchor.size()[1])

        x_candi = self.linear_candi(fea_d)
        x_candi = self.relu_candi(x_candi)

        x = torch.cat([x_anchor, x_candi], 1)

        x = self.linear1(x)
        x = self.relu1(x)

        x = self.linear2(x)

        logit = x.view(-1)

        return logit

    def bp(self, fake_logit: Variable, prob):
        bs = fake_logit.size()[0]
        self.opt.zero_grad()
        reward = torch.tanh(fake_logit.detach())
        # loss = -(torch.mean(torch.log(prob) * reward)).backward()
        torch.log(prob).backward(-reward / bs)

        self.opt.step()
Example #7
def train(
    config: TrainConfig,
    model: BartForConditionalGeneration,
    train_dataloader: DataLoader,
    dev_dataloader: DataLoader,
    optimizer: Adam,
    logger: logging.Logger,
    device: torch.device,
):
    """ 지정된 Epoch만큼 모델을 학습시키는 함수입니다. """
    model.to(device)
    global_step = 0
    for epoch in range(1, config.num_epochs + 1):
        model.train()
        loss_sum = 0.0
        for data in train_dataloader:
            global_step += 1
            data = _change_device(data, device)
            optimizer.zero_grad()
            output = model.forward(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss = output["loss"]
            loss.backward()
            loss_sum += loss.item()

            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if global_step % config.train_log_interval == 0:
                mean_loss = loss_sum / config.train_log_interval
                logger.info(
                    f"Epoch {epoch} Step {global_step} " f"Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}"
                )
                loss_sum = 0.0
            if global_step % config.dev_log_interval == 0:
                _validate(model, dev_dataloader, logger, device)
            if global_step % config.save_interval == 0:
                model.save_pretrained(f"{config.save_model_file_prefix}_{global_step}")
Example #8
def a2c_step(config: ParamDict, policy: PolicyWithValue,
             replay_memory: StepDictList):
    config.require("l2 reg")
    l2_reg = config["l2 reg"]

    policy.estimate_advantage(replay_memory)
    states, actions, returns, advantages = get_tensor(policy, replay_memory,
                                                      policy.device)
    """update critic"""
    update_value_net(policy.value_net, returns, states, l2_reg)
    """update policy"""
    optimizer_policy = Adam(policy.policy_net.parameters(),
                            lr=1.e-3,
                            weight_decay=l2_reg)
    log_probs = policy.policy_net.get_log_prob(states, actions)
    policy_loss = -(log_probs * advantages).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()
Example #9
def ppo_step(config: ParamDict, replay_memory: StepDictList,
             policy: PolicyWithValue):
    lr, l2_reg, clip_epsilon, policy_iter, i_iter, max_iter, mini_batch_sz = \
        config.require("lr", "l2 reg", "clip eps", "optimize policy epochs",
                       "current training iter", "max iter", "optimize batch size")
    lam_entropy = 0.0
    states, actions, advantages, returns = get_tensor(policy, replay_memory,
                                                      policy.device)
    """update critic"""
    update_value_net(policy.value_net, states, returns, l2_reg)
    """update policy"""
    lr_mult = max(1.0 - i_iter / max_iter, 0.)
    clip_epsilon = clip_epsilon * lr_mult
    optimizer = Adam(policy.policy_net.parameters(),
                     lr=lr * lr_mult,
                     weight_decay=l2_reg)

    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states,
                                                         actions).detach()

    for _ in range(policy_iter):
        inds = torch.randperm(states.size(0))  # fresh torch-native shuffle; np.random.shuffle is unreliable on tensors
        """perform mini-batch PPO update"""
        for i_b in range(int(np.ceil(states.size(0) / mini_batch_sz))):
            ind = inds[i_b * mini_batch_sz:min((i_b + 1) *
                                               mini_batch_sz, inds.size(0))]

            log_probs, entropy = policy.policy_net.get_log_prob_entropy(
                states[ind], actions[ind])
            ratio = torch.exp(log_probs - fixed_log_probs[ind])
            surr1 = ratio * advantages[ind]
            surr2 = torch.clamp(ratio, 1.0 - clip_epsilon,
                                1.0 + clip_epsilon) * advantages[ind]
            policy_surr = -torch.min(surr1, surr2).mean()
            policy_surr -= entropy.mean() * lam_entropy
            optimizer.zero_grad()
            policy_surr.backward()
            torch.nn.utils.clip_grad_norm_(policy.policy_net.parameters(), 0.5)
            optimizer.step()
Example #10
class TrainTest:
    """
    Initialize the model and build the optimizer and loss function.
    """
    def __init__(self):
        self.model = _Classifier().to(DEVICE)
        self.optimizer = Adam(self.model.parameters(), lr=Params.LR)
        self.criterion = CrossEntropyLoss()

    def train_test(self, X, X_len, **kwargs):
        """
        train and test function

        Args:
            X (tensor): training data
            X_len (tensor): length

        Returns:
            y_pred (tensor): predicted class indices (eval mode only)
        """
        y = kwargs.pop('y', None)

        self.optimizer.zero_grad()

        if y is None:
            # eval can disable dropout
            self.model.eval()
            with torch.no_grad():
                output = self.model(X, X_len)
                y_pred = output.detach().argmax(1)
            return y_pred
        else:
            self.model.train()
            output = self.model(X, X_len)
            loss = self.criterion(output, y)
            # if multi gpu, remember use loss.mean()
            loss.backward()
            self.optimizer.step()

    def save(self):
        torch.save(self.model.state_dict(), Params.PATH_SAVE)
Example #11
def update_value_net(value_net,
                     states,
                     returns,
                     l2_reg,
                     mini_batch_sz=512,
                     iter=10):
    optim = Adam(value_net.parameters(), weight_decay=l2_reg)

    for i in range(iter):
        inds = torch.randperm(returns.size(0))  # torch-native shuffle
        for i_b in range(returns.size(0) // mini_batch_sz):
            b_ind = inds[i_b * mini_batch_sz:min((i_b + 1) *
                                                 mini_batch_sz, inds.size(0))]

            value_loss = (value_net(states[b_ind]) -
                          returns[b_ind]).pow(2).mean()
            optim.zero_grad()
            value_loss.backward()
            optim.step()
Example #12
    def train(self):
        total_step = len(self.data_loader)
        optimizer = Adam(self.transfer_net.parameters(), lr=self.lr)
        criterion = nn.MSELoss()
        self.transfer_net.train()

        for epoch in range(self.epoch, self.num_epoch):
            for step, image in enumerate(self.data_loader):
                image = image.to(self.device)
                transformed_image = self.transfer_net(image)

                image_feature = self.vgg(image)
                transformed_image_feature = self.vgg(transformed_image)

                content_loss = self.content_weight * criterion(
                    image_feature, transformed_image_feature)

                style_loss = 0
                for ft_y, gm_s in zip(transformed_image_feature,
                                      self.gram_style):
                    gm_y = gram_matrix(ft_y)
                    style_loss += criterion(gm_y,
                                            gm_s[:self.batch_size, :, :])
                style_loss *= self.style_weight

                total_loss = content_loss + style_loss

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                if step % 10 == 0:
                    print(
                        f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                        f"[Style loss: {style_loss.item()}] [Content loss loss: {content_loss.item()}]"
                    )
            torch.save(
                self.transfer_net.state_dict(),
                os.path.join(self.checkpoint_dir, f"TransferNet_{epoch}.pth"))
Example #13
def generate_layer_to_rgb(layer_dim):
    print('Generating layer-to-rgb convolution...')
    start_time = time.time()

    to_rgb = nn.Sequential(nn.Conv2d(layer_dim, 3, 1, 1, 0), nn.Tanh())

    to_layer = nn.Conv2d(3, layer_dim, 1, 1, 0)
    ae = nn.Sequential(to_rgb, to_layer)
    optim = Adam(ae.parameters(), lr=1e-3)

    for _ in range(256):
        noise = torch.randn((128, layer_dim, 1, 1))
        loss = F.mse_loss(ae(noise), noise)

        optim.zero_grad()
        loss.backward()
        optim.step()

    elapsed_time = (time.time() - start_time) * 1000
    print(f"Finished in {elapsed_time:.1f} ms.")

    return nn.Sequential(to_rgb, nn.UpsamplingNearest2d(scale_factor=16))
Example #14
class Trainer():
    def __init__(self, model, n_epochs=10, criterion=nn.CrossEntropyLoss):

        self.trainset, self.testset = fashion_mnist()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.model.to(self.device)
        self.criterion = criterion()
        self.optim = Adam(model.parameters(), lr=0.01)

    def train(self, n_epochs, batch_size, print_every=100):
        losses = []
        accuracies = []
        for e in range(n_epochs):
            batch_idx = 0
            for x, y in tqdm(DataLoader(self.trainset, batch_size=batch_size)):
                x, y = x.to(self.device), y.to(self.device)

                def closure():
                    self.optim.zero_grad()
                    # returning output and feature map positive and negative
                    pred = self.model(x)
                    loss = self.criterion(pred, y)
                    loss.backward()
                    losses.append(loss.item())
                    accuracies.append((pred.argmax(1) == y).float().mean().item())
                    if batch_idx % print_every == 0:
                        print(f'\r epoch {e} | loss: {loss.item():.4f}'
                              f' | acc: {accuracies[-1]:.4f}')
                    return loss

                self.optim.step(closure)
                batch_idx += 1

    def test(self):
        pass
Example #15
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.feature_extractor = resnet50()  # Already pretrained
        # self.feature_extractor = resnet50(pretrained_path=None)
        self.selector = Selector()
        self.dis = Discriminator()
        self.optmzr_select = Adam(self.selector.parameters(), lr=1e-3)
        self.optmzr_dis = Adam(self.dis.parameters(), lr=1e-3)

    def forward(self, anchor: Variable, real_data: Variable, fake_data: Variable):
        assert (len(anchor.size()) == 4 and len(real_data.size()) == 4
                and len(fake_data.size()) == 4)

        fea_anchor = self.feature_extractor(anchor)
        fea_real = self.feature_extractor(real_data)
        fea_fake = self.feature_extractor(fake_data)

        # not train_feature:
        fea_anchor = fea_anchor.detach()
        fea_real = fea_real.detach()
        fea_fake = fea_fake.detach()

        score_real = self.dis(fea_anchor, fea_real)
        score_fake = self.dis(fea_anchor, fea_fake)

        return score_real, score_fake

    def bp_dis(self, score_real, score_fake):
        real_label = Variable(torch.normal(torch.ones(score_real.size()),
                                           torch.zeros(score_real.size()) + 0.05)).cuda()
        fake_label = Variable(torch.normal(torch.zeros(score_real.size()),
                                           torch.zeros(score_real.size()) + 0.05)).cuda()
        loss = torch.mean(F.binary_cross_entropy(score_real, real_label, reduction='sum') +
                          F.binary_cross_entropy(score_fake, fake_label, reduction='sum'))

        # loss = -(torch.mean(torch.log(score_real + 1e-6)) - torch.mean(torch.log(.5 + score_fake / 2 + 1e-6)))

        self.optmzr_dis.zero_grad()
        loss.backward()
        self.optmzr_dis.step()  # step() returns None, so return the loss instead
        return loss

    def bp_select(self, score_fake: Variable, fake_prob):
        # torch.mean(torch.log(prob) * torch.log(1 - score_fake), 0)
        n_sample = score_fake.size()[0]
        self.optmzr_select.zero_grad()  # update the selector's optimizer, not the discriminator's
        re = (score_fake.data - .5) * 2
        torch.log(fake_prob).backward(re / n_sample)
        self.optmzr_select.step()
Example #16
def train(model, SRC, TRG, MODEL_PATH, FORCE_MAX_LEN=50):
    model.train()
    optimizer = Adam(model.parameters(), lr=hp.LR, betas=(0.9, 0.98), eps=1e-9)
    criterion = CrossEntropyLoss(ignore_index=TRG.vocab.stoi["<pad>"])

    for epoch in tqdm(range(hp.EPOCHS)):

        for step, batch in enumerate(train_iter):
            global_step = epoch * len(train_iter) + step

            model.train()
            optimizer.zero_grad()
            optimizer = custom_lr_optimizer(optimizer, global_step)

            src = batch.src.T
            trg = batch.trg.T

            trg_input = trg[:, :-1]

            preds, _, _, _ = model(src, trg_input)
            ys = trg[:, 1:]

            loss = criterion(preds.permute(0, 2, 1), ys)
            loss.mean().backward()
            optimizer.step()

            if global_step % 50 == 0:
                print("#" * 90)

                rand_index = random.randrange(hp.BATCH_SIZE)

                model.eval()

                with torch.no_grad():  # validation forward pass needs no graph
                    v = next(iter(val_iter))
                    v_src, v_trg = v.src.T, v.trg.T

                    v_trg_inp = v_trg[:, :-1]
                    v_trg_real = v_trg[:, 1:]

                    v_predictions, _, _, _ = model(v_src, v_trg_inp)
                max_args = v_predictions[rand_index].argmax(-1)
                print("For random element in VALIDATION batch (real/pred)...")
                print([
                    TRG.vocab.itos[word_idx]
                    for word_idx in v_trg_real[rand_index, :]
                ])
                print([TRG.vocab.itos[word_idx] for word_idx in max_args])

                print("Length til first <PAD> (real -> pred)...")
                try:
                    pred_PAD_idx = max_args.tolist().index(3)
                except ValueError:
                    pred_PAD_idx = None

                print(v_trg_real[rand_index, :].tolist().index(3), "  --->  ",
                      pred_PAD_idx)

                val_loss = criterion(v_predictions.permute(0, 2, 1),
                                     v_trg_real)
                print("TRAINING LOSS:", loss.mean().item())
                print("VALIDATION LOSS:", val_loss.mean().item())

                print("#" * 90)

                writer.add_scalar("Training Loss",
                                  loss.mean().detach().item(), global_step)
                writer.add_scalar("Validation Loss",
                                  val_loss.mean().detach().item(), global_step)
        torch.save(model, MODEL_PATH)
Example #17
class DPProcessor(object):
    def __init__(self, cfg_path):
        with open(cfg_path, 'r') as rf:
            self.cfg = yaml.safe_load(rf)
        self.data_cfg = self.cfg['data']
        self.model_cfg = self.cfg['model']
        self.optim_cfg = self.cfg['optim']
        self.val_cfg = self.cfg['val']
        print(self.data_cfg)
        print(self.model_cfg)
        print(self.optim_cfg)
        print(self.val_cfg)
        self.tdata = MSCOCO(
            img_root=self.data_cfg['train_img_root'],
            ann_path=self.data_cfg['train_ann_path'],
            debug=self.data_cfg['debug'],
            augment=True,
        )
        self.tloader = DataLoader(dataset=self.tdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.tdata.collate_fn,
                                  shuffle=True)
        self.vdata = MSCOCO(
            img_root=self.data_cfg['val_img_root'],
            ann_path=self.data_cfg['val_ann_path'],
            debug=False,
            augment=False,
        )
        self.vloader = DataLoader(dataset=self.vdata,
                                  batch_size=self.data_cfg['batch_size'],
                                  num_workers=self.data_cfg['num_workers'],
                                  collate_fn=self.vdata.collate_fn,
                                  shuffle=False)
        print("train_data: ", len(self.tdata), " | ", "val_data: ",
              len(self.vdata))
        print("train_iter: ", len(self.tloader), " | ", "val_iter: ",
              len(self.vloader))
        model: torch.nn.Module = getattr(
            eval(self.model_cfg['type']),
            self.model_cfg['name'])(pretrained=self.model_cfg['pretrained'],
                                    num_classes=self.model_cfg['num_joints'],
                                    reduction=self.model_cfg['reduction'])
        self.scaler = amp.GradScaler(
            enabled=True) if self.optim_cfg['amp'] else None
        self.optimizer = Adam(model.parameters(), lr=self.optim_cfg['lr'])
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=self.optim_cfg['milestones'],
            gamma=self.optim_cfg['gamma'])
        # self.lr_scheduler = IterWarmUpCosineDecayMultiStepLRAdjust(
        #     init_lr=self.optim_cfg['lr'],
        #     milestones=self.optim_cfg['milestones'],
        #     warm_up_epoch=1,
        #     iter_per_epoch=len(self.tloader),
        #     epochs=self.optim_cfg['epochs']
        # )

        assert torch.cuda.is_available(), "training only support cuda"
        assert torch.cuda.device_count() >= len(
            self.cfg['gpus']), "not have enough gpus"
        self.inp_device = torch.device("cuda:{:d}".format(self.cfg['gpus'][0]))
        self.out_device = torch.device("cuda:{:d}".format(
            self.cfg['gpus'][-1]))
        model.to(self.inp_device)
        self.model = nn.DataParallel(model,
                                     device_ids=self.cfg['gpus'],
                                     output_device=self.out_device)
        # self.ema = ModelEMA(self.model)
        self.criterion = nn.MSELoss()
        self.acc_func = HeatMapAcc()
        self.best_ap = 0.
        self.loss_logger = AverageLogger()
        self.acc_logger = AverageLogger()
        self.decoder = BasicKeyPointDecoder()

    def train(self, epoch):
        self.loss_logger.reset()
        self.acc_logger.reset()
        self.model.train()
        pbar = tqdm(self.tloader)
        print("#" * 25, "training start", "#" * 25)
        for i, (input_tensors, heat_maps, masks, _, _) in enumerate(pbar):
            input_img = input_tensors.to(self.inp_device)
            targets = heat_maps.to(self.out_device)
            mask = masks.to(self.out_device)
            self.optimizer.zero_grad()
            if self.scaler is None:
                predicts = self.model(input_img)
                loss = 0.5 * self.criterion(
                    predicts.mul(mask[..., None, None]),
                    targets.mul(mask[..., None, None]))
                loss.backward()
                # nn.utils.clip_grad_norm_(self.model.parameters(), 2)
                # self.lr_scheduler(self.optimizer, i, epoch)
                self.optimizer.step()
            else:
                with amp.autocast(enabled=True):
                    predicts = self.model(input_img)
                    loss = 0.5 * self.criterion(
                        predicts.mul(mask[..., None, None]),
                        targets.mul(mask[..., None, None]))
                self.scaler.scale(loss).backward()
                # nn.utils.clip_grad_norm_(self.model.parameters(), 2)
                # self.lr_scheduler(self.optimizer, i, epoch)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            # self.ema.update(self.model)
            lr = self.optimizer.param_groups[0]['lr']
            acc = self.acc_func(
                predicts.mul(mask[..., None, None]).detach(),
                targets.mul(mask[..., None, None]).detach())
            self.loss_logger.update(loss.item())
            self.acc_logger.update(acc.item())
            pbar.set_description(
                "train epoch:{:3d}|iter:{:4d}|loss:{:8.6f}|acc:{:6.4f}|lr:{:8.6f}"
                .format(
                    epoch + 1,
                    i,
                    self.loss_logger.avg(),
                    self.acc_logger.avg() * 100,
                    lr,
                ))
        # self.ema.update_attr(self.model)
        self.lr_scheduler.step()
        print()
        print("#" * 25, "training end", "#" * 25)

    @torch.no_grad()
    def val(self, epoch):
        self.loss_logger.reset()
        self.acc_logger.reset()
        self.model.eval()
        # self.ema.ema.eval()
        pbar = tqdm(self.vloader)
        kps_dict_list = list()
        print("#" * 25, "evaluating start", "#" * 25)
        for i, (input_tensors, heat_maps, masks, trans_invs,
                img_ids) in enumerate(pbar):
            input_img = input_tensors.to(self.inp_device)
            targets = heat_maps.to(self.out_device)
            tran_inv = trans_invs.to(self.out_device)
            mask = masks.to(self.out_device)
            predicts = self.model(input_img)
            loss = 0.5 * self.criterion(predicts.mul(mask[..., None, None]),
                                        targets.mul(mask[..., None, None]))
            acc = self.acc_func(predicts.mul(mask[..., None, None]),
                                targets.mul(mask[..., None, None]))
            self.loss_logger.update(loss.item())
            self.acc_logger.update(acc.item())
            pred_kps, scores = self.decoder(predicts, tran_inv)
            kps_to_dict_(pred_kps, scores, img_ids, kps_dict_list)
            pbar.set_description(
                "eval epoch:{:3d}|iter:{:4d}|loss:{:8.6f}|acc:{:6.4f}".format(
                    epoch + 1,
                    i,
                    self.loss_logger.avg(),
                    self.acc_logger.avg() * 100,
                ))
        with open("temp_test.json", "w") as wf:
            json.dump(kps_dict_list, wf)
        val_ap = evaluate_map("temp_test.json",
                              self.data_cfg['val_ann_path'])['AP']
        print(
            "eval epoch:{:d}|mean_loss:{:8.6f}|mean_acc:{:6.4f}|val_ap:{:6.4f}"
            .format(epoch + 1, self.loss_logger.avg(),
                    self.acc_logger.avg() * 100, val_ap))
        print("#" * 25, "evaluating end", "#" * 25)

        cpkt = {
            "ema": self.model.module.state_dict(),
            "epoch": epoch,
        }
        if val_ap > self.best_ap:
            self.best_ap = val_ap
            best_weight_path = os.path.join(
                self.val_cfg['weight_path'],
                "{:s}_best.pth".format(self.model_cfg['type']))
            torch.save(cpkt, best_weight_path)
        last_weight_path = os.path.join(
            self.val_cfg['weight_path'],
            "{:s}_last.pth".format(self.model_cfg['type']))
        torch.save(cpkt, last_weight_path)

    def run(self):
        for epoch in range(self.optim_cfg['epochs']):
            self.train(epoch)
            if (epoch + 1) % self.val_cfg['interval'] == 0:
                self.val(epoch)
Example #18
class SACDrQ(Agent):
    # https://arxiv.org/abs/2004.13649
    def __init__(self,
                 algo_params,
                 env,
                 transition_tuple=None,
                 path=None,
                 seed=-1):
        # environment
        self.env = PixelPybulletGym(
            env,
            image_size=algo_params['image_resize_size'],
            crop_size=algo_params['image_crop_size'])
        self.frame_stack = algo_params['frame_stack']
        self.env = FrameStack(self.env, k=self.frame_stack)
        self.env.seed(seed)
        obs = self.env.reset()
        algo_params.update({
            'state_shape':
            obs.shape,  # make sure the shape is like (C, H, W), not (H, W, C)
            'action_dim': self.env.action_space.shape[0],
            'action_max': self.env.action_space.high,
            'action_scaling': self.env.action_space.high[0],
        })
        # training args
        self.max_env_step = algo_params['max_env_step']
        self.testing_gap = algo_params['testing_gap']
        self.testing_episodes = algo_params['testing_episodes']
        self.saving_gap = algo_params['saving_gap']

        super(SACDrQ, self).__init__(algo_params,
                                     transition_tuple=transition_tuple,
                                     image_obs=True,
                                     training_mode='step_based',
                                     path=path,
                                     seed=seed)
        # torch
        self.encoder = PixelEncoder(self.state_shape)
        self.encoder_target = PixelEncoder(self.state_shape)
        self.network_dict.update({
            'actor':
            StochasticConvActor(self.action_dim,
                                encoder=self.encoder,
                                detach_obs_encoder=True).to(self.device),
            'critic_1':
            ConvCritic(self.action_dim,
                       encoder=self.encoder,
                       detach_obs_encoder=False).to(self.device),
            'critic_1_target':
            ConvCritic(self.action_dim,
                       encoder=self.encoder_target,
                       detach_obs_encoder=True).to(self.device),
            'critic_2':
            ConvCritic(self.action_dim,
                       encoder=self.encoder,
                       detach_obs_encoder=False).to(self.device),
            'critic_2_target':
            ConvCritic(self.action_dim,
                       encoder=self.encoder_target,
                       detach_obs_encoder=True).to(self.device),
            'alpha':
            algo_params['alpha'],
            'log_alpha':
            T.tensor(np.log(algo_params['alpha']),
                     requires_grad=True,
                     device=self.device),
        })
        self.network_keys_to_save = ['actor']
        self.actor_optimizer = Adam(self.network_dict['actor'].parameters(),
                                    lr=self.actor_learning_rate)
        self.critic_1_optimizer = Adam(
            self.network_dict['critic_1'].parameters(),
            lr=self.critic_learning_rate)
        self.critic_2_optimizer = Adam(
            self.network_dict['critic_2'].parameters(),
            lr=self.critic_learning_rate)
        self._soft_update(self.network_dict['critic_1'],
                          self.network_dict['critic_1_target'],
                          tau=1)
        self._soft_update(self.network_dict['critic_2'],
                          self.network_dict['critic_2_target'],
                          tau=1)
        self.target_entropy = -self.action_dim
        self.alpha_optimizer = Adam([self.network_dict['log_alpha']],
                                    lr=self.actor_learning_rate)
        # augmentation args
        self.image_random_shift = T.nn.Sequential(
            T.nn.ReplicationPad2d(4), aug.RandomCrop(self.state_shape[-2:]))
        self.q_regularisation_k = algo_params['q_regularisation_k']
        # training args
        self.warmup_step = algo_params['warmup_step']
        self.actor_update_interval = algo_params['actor_update_interval']
        self.critic_target_update_interval = algo_params[
            'critic_target_update_interval']
        # statistic dict
        self.statistic_dict.update({
            'episode_return': [],
            'env_step_return': [],
            'env_step_test_return': [],
            'alpha': [],
            'policy_entropy': [],
        })

    def run(self, test=False, render=False, load_network_ep=None, sleep=0):
        if test:
            num_episode = self.testing_episodes
            if load_network_ep is not None:
                print("Loading network parameters...")
                self._load_network(ep=load_network_ep)
            print("Start testing...")
            for ep in range(num_episode):
                ep_return = self._interact(render, test, sleep=sleep)
                self.statistic_dict['episode_return'].append(ep_return)
                print("Episode %i" % ep, "return %0.1f" % ep_return)
            print("Finished testing")
        else:
            print("Start training...")
            step_returns = 0
            while self.env_step_count < self.max_env_step:
                ep_return = self._interact(render, test, sleep=sleep)
                step_returns += ep_return
                if self.env_step_count % 1000 == 0:
                    # cumulative rewards every 1000 env steps
                    self.statistic_dict['env_step_return'].append(step_returns)
                    print(
                        "Env step %i" % self.env_step_count,
                        "avg return %0.1f" %
                        self.statistic_dict['env_step_return'][-1])
                    step_returns = 0

                if (self.env_step_count % self.testing_gap
                        == 0) and (self.env_step_count != 0) and (not test):
                    ep_test_return = []
                    for test_ep in range(self.testing_episodes):
                        ep_test_return.append(self._interact(render,
                                                             test=True))
                    self.statistic_dict['env_step_test_return'].append(
                        sum(ep_test_return) / self.testing_episodes)
                    print(
                        "Env step %i" % self.env_step_count,
                        "test return %0.1f" %
                        (sum(ep_test_return) / self.testing_episodes))

                if (self.env_step_count % self.saving_gap
                        == 0) and (self.env_step_count != 0) and (not test):
                    self._save_network(step=self.env_step_count)

            print("Finished training")
            print("Saving statistics...")
            self._plot_statistics(x_labels={
                'env_step_return':
                'Environment step (x1e3)',
                'env_step_test_return':
                'Environment step (x1e4)'
            },
                                  save_to_file=True)

    def _interact(self, render=False, test=False, sleep=0):
        done = False
        obs = self.env.reset()
        # build frame buffer for frame stack observations
        ep_return = 0
        # start a new episode
        while not done:
            if render:
                self.env.render()
            if self.env_step_count < self.warmup_step:
                action = self.env.action_space.sample()
            else:
                action = self._select_action(obs, test=test)
            new_obs, reward, done, info = self.env.step(action)
            time.sleep(sleep)
            ep_return += reward
            if not test:
                self._remember(obs, action, new_obs, reward, 1 - int(done))
                if (self.env_step_count % self.update_interval
                        == 0) and (self.env_step_count > self.warmup_step):
                    self._learn()
                self.env_step_count += 1
                if self.env_step_count % 1000 == 0:
                    break
            obs = new_obs
        return ep_return

    def _select_action(self, obs, test=False):
        obs = T.as_tensor([obs], dtype=T.float, device=self.device)
        return self.network_dict['actor'].get_action(
            obs, mean_pi=test).detach().cpu().numpy()[0]

    def _learn(self, steps=None):
        if len(self.buffer) < self.batch_size:
            return
        if steps is None:
            steps = self.optimizer_steps

        for i in range(steps):
            if self.prioritised:
                batch, weights, inds = self.buffer.sample(self.batch_size)
                weights = T.as_tensor(weights, device=self.device).view(
                    self.batch_size, 1)
            else:
                batch = self.buffer.sample(self.batch_size)
                weights = T.ones(size=(self.batch_size, 1), device=self.device)
                inds = None

            vanilla_actor_inputs = T.as_tensor(batch.state,
                                               dtype=T.float32,
                                               device=self.device)
            actions = T.as_tensor(batch.action,
                                  dtype=T.float32,
                                  device=self.device)
            vanilla_actor_inputs_ = T.as_tensor(batch.next_state,
                                                dtype=T.float32,
                                                device=self.device)
            rewards = T.as_tensor(batch.reward,
                                  dtype=T.float32,
                                  device=self.device).unsqueeze(1)
            done = T.as_tensor(batch.done, dtype=T.float32,
                               device=self.device).unsqueeze(1)

            if self.discard_time_limit:
                done = done * 0 + 1  # 'done' stores a not-done mask; force 1 to ignore time-limit terminations

            average_value_target = 0
            for _ in range(self.q_regularisation_k):
                actor_inputs_ = self.image_random_shift(vanilla_actor_inputs_)
                with T.no_grad():
                    actions_, log_probs_ = self.network_dict[
                        'actor'].get_action(actor_inputs_, probs=True)
                    value_1_ = self.network_dict['critic_1_target'](
                        actor_inputs_, actions_)
                    value_2_ = self.network_dict['critic_2_target'](
                        actor_inputs_, actions_)
                    value_ = T.min(value_1_, value_2_) - (
                        self.network_dict['alpha'] * log_probs_)
                    average_value_target = average_value_target + (
                        rewards + done * self.gamma * value_)
            value_target = average_value_target / self.q_regularisation_k

            self.critic_1_optimizer.zero_grad()
            self.critic_2_optimizer.zero_grad()
            aggregated_critic_loss_1 = 0
            aggregated_critic_loss_2 = 0
            for _ in range(self.q_regularisation_k):
                actor_inputs = self.image_random_shift(vanilla_actor_inputs)

                value_estimate_1 = self.network_dict['critic_1'](actor_inputs,
                                                                 actions)
                critic_loss_1 = F.mse_loss(value_estimate_1,
                                           value_target.detach(),
                                           reduction='none')
                aggregated_critic_loss_1 = aggregated_critic_loss_1 + critic_loss_1

                value_estimate_2 = self.network_dict['critic_2'](actor_inputs,
                                                                 actions)
                critic_loss_2 = F.mse_loss(value_estimate_2,
                                           value_target.detach(),
                                           reduction='none')
                aggregated_critic_loss_2 = aggregated_critic_loss_2 + critic_loss_2

            # run backward for both losses before calling .step(); interleaving step() with a later backward() raises a CUDA runtime error
            avg_critic_loss_1 = aggregated_critic_loss_1 / self.q_regularisation_k
            (avg_critic_loss_1 * weights).mean().backward()
            avg_critic_loss_2 = aggregated_critic_loss_2 / self.q_regularisation_k
            (avg_critic_loss_2 * weights).mean().backward()
            self.critic_1_optimizer.step()
            self.critic_2_optimizer.step()

            if self.prioritised:
                assert inds is not None
                avg_critic_loss_1 = avg_critic_loss_1.detach().cpu().numpy()
                self.buffer.update_priority(inds, np.abs(avg_critic_loss_1))

            self.statistic_dict['critic_loss'].append(
                avg_critic_loss_1.mean().detach())

            if self.optim_step_count % self.critic_target_update_interval == 0:
                self._soft_update(self.network_dict['critic_1'],
                                  self.network_dict['critic_1_target'])
                self._soft_update(self.network_dict['critic_2'],
                                  self.network_dict['critic_2_target'])

            if self.optim_step_count % self.actor_update_interval == 0:
                self.actor_optimizer.zero_grad()
                self.alpha_optimizer.zero_grad()
                aggregated_actor_loss = 0
                aggregated_alpha_loss = 0
                aggregated_log_probs = 0
                for _ in range(self.q_regularisation_k):
                    actor_inputs = self.image_random_shift(
                        vanilla_actor_inputs)
                    new_actions, new_log_probs = self.network_dict[
                        'actor'].get_action(actor_inputs, probs=True)
                    aggregated_log_probs = aggregated_log_probs + new_log_probs
                    new_values = T.min(
                        self.network_dict['critic_1'](actor_inputs,
                                                      new_actions),
                        self.network_dict['critic_2'](actor_inputs,
                                                      new_actions))
                    aggregated_actor_loss = aggregated_actor_loss + (
                        self.network_dict['alpha'] * new_log_probs -
                        new_values).mean()
                    aggregated_alpha_loss = aggregated_alpha_loss + (
                        self.network_dict['log_alpha'] *
                        (-new_log_probs -
                         self.target_entropy).detach()).mean()

                avg_actor_loss = aggregated_actor_loss / self.q_regularisation_k
                avg_actor_loss.backward()
                avg_alpha_loss = aggregated_alpha_loss / self.q_regularisation_k
                avg_alpha_loss.backward()
                self.actor_optimizer.step()
                self.alpha_optimizer.step()
                self.network_dict['alpha'] = self.network_dict[
                    'log_alpha'].exp()

                self.statistic_dict['actor_loss'].append(
                    avg_actor_loss.mean().detach())
                self.statistic_dict['alpha'].append(
                    self.network_dict['alpha'].detach())
                self.statistic_dict['policy_entropy'].append(
                    (-aggregated_log_probs /
                     self.q_regularisation_k).mean().detach())

            self.optim_step_count += 1
Example #19
class DQN(object):
    def __init__(self, state_shape, action_shape, action_value_model, **args):
        super(DQN, self).__init__()
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.device = args.get("Device",
                               "cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = args.get("Gamma", 0.95)
        self.batch_size = args.get("BatchSize", 64)
        self.action_value_model = action_value_model.to(self.device)
        self.optimizer = Adam(self.action_value_model.parameters(),
                              lr=args.get("LearningRate", 1e-3))
        self.rpm = Buffer(args.get("ReplayMemorySize", 10000))

    def update_model(self):
        # sample a batch update the model
        self.action_value_model.train()
        sample_batch = self.rpm.sample_batch(self.batch_size)
        states, actions, next_states, rewards, done = (
            self.to_tensor(value) for value in decode_batch(sample_batch))
        target_values = self.action_value_model(self.norm(next_states)).max(
            1)[0].detach() * self.gamma * (1 - done) + rewards
        current_values = (self.action_value_model(self.norm(states)) *
                          actions).sum(1)
        td_errors = target_values - current_values

        loss = (td_errors**2).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss

    def choose_action(self, state, epsilon):
        # epsilon-greedy policy
        states = state[np.newaxis, :]
        random_actions = self.random_action(states)
        if epsilon == 1:
            return random_actions[0]
        greedy_actions = self.greedy_action(states)
        actions = [
            ra if np.random.rand() < epsilon else ga
            for ra, ga in zip(random_actions, greedy_actions)
        ]
        actions = np.array(actions)
        return actions[0]

    def random_action(self, state):
        out = np.random.rand(state.shape[0], self.action_shape)
        out[:, 1] *= 0.2
        actions = decode_action(out)
        return actions

    def greedy_action(self, state):
        state_tensor = self.to_tensor(state)
        out = self.to_array(self.action_value_model(self.norm(state_tensor)))
        actions = decode_action(out)
        return actions

    def observe_state(self, trans_tuples):
        # trans_tuples: <state, action, next_state, reward, done>
        self.rpm.append(trans_tuples)

    def norm(self, state):
        # roughly normalize the state variable
        normed = (state.permute([0, 2, 3, 1]) - torch.FloatTensor([40.27] * 4).to(self.device)) / \
                 torch.FloatTensor([40.27] * 4).to(self.device)
        return normed.permute([0, 3, 1, 2])

    def train(self):
        # switch to train mode
        self.action_value_model.train()

    def test(self):
        # switch to test mode
        self.action_value_model.eval()

    def save_state_dict(self, path):
        torch.save(self.action_value_model.state_dict(), path)
        return True

    def load_state_dict(self, path):
        self.action_value_model.load_state_dict(torch.load(path))
        return True

    def to_tensor(self, array):
        return torch.tensor(array).float().to(self.device)

    def to_array(self, tensor):
        if self.device == "cpu":
            return tensor.data.numpy()
        else:
            return tensor.to("cpu").data.numpy()
Example #20
def main():
  """ train model
  """
  try:
    os.makedirs(opt.checkpoints_dir)
  except OSError:
    pass
  ################################################
  #               load train dataset
  ################################################
  dataset = dset.ImageFolder(root=opt.dataroot,
                             transform=transforms.Compose([
                               transforms.Resize(opt.img_size),
                               transforms.ToTensor(),
                               transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                             ]))

  assert dataset
  dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                           shuffle=True, num_workers=int(opt.workers))

  if torch.cuda.device_count() > 1:
    netG = torch.nn.DataParallel(Generator())
  else:
    netG = Generator()
  if os.path.exists(opt.netG):
    netG.load_state_dict(torch.load(opt.netG, map_location=lambda storage, loc: storage))

  if torch.cuda.device_count() > 1:
    netD = torch.nn.DataParallel(Discriminator())
  else:
    netD = Discriminator()
  if os.path.exists(opt.netD):
    netD.load_state_dict(torch.load(opt.netD, map_location=lambda storage, loc: storage))

  # set train mode
  netG.train()
  netG = netG.to(device)
  netD.train()
  netD = netD.to(device)
  print(netG)
  print(netD)

  ################################################
  #            Use Adam optimizer
  ################################################
  optimizerD = Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
  optimizerG = Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))

  ################################################
  #               print args
  ################################################
  print("########################################")
  print(f"train dataset path: {opt.dataroot}")
  print(f"batch size: {opt.batch_size}")
  print(f"image size: {opt.img_size}")
  print(f"Epochs: {opt.n_epochs}")
  print(f"Noise size: {opt.nz}")
  print("########################################")
  print("Starting trainning!")
  for epoch in range(opt.n_epochs):
    for i, data in enumerate(dataloader):
      # get data
      real_imgs = data[0].to(device)
      batch_size = real_imgs.size(0)

      # Sample noise as generator input
      z = torch.randn(batch_size, opt.nz, 1, 1, device=device)

      ##############################################
      # (1) Update D network: maximize D(x) - D(G(z)) (WGAN critic loss with gradient penalty)
      ##############################################

      optimizerD.zero_grad()

      # Generate a batch of images
      fake_imgs = netG(z)

      real_validity = netD(real_imgs)
      fake_validity = netD(fake_imgs)

      # Gradient penalty
      gradient_penalty = calculate_gradient_penatly(netD, real_imgs.data, fake_imgs.data)

      # Loss measures generator's ability to fool the discriminator
      loss_D = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty * 10

      loss_D.backward()
      optimizerD.step()

      optimizerG.zero_grad()

      ##############################################
      # (2) Update G network: maximize D(G(z)) (WGAN generator loss)
      ##############################################
      if i % opt.n_critic == 0:
        # Generate a batch of images
        fake_imgs = netG(z)

        # Train on fake images
        loss_G = -torch.mean(netD(fake_imgs))

        loss_G.backward()
        optimizerG.step()

        print(f"Epoch->[{epoch + 1:03d}/{opt.n_epochs:03d}] "
              f"Progress->{i / len(dataloader) * 100:4.2f}% "
              f"Loss_D: {loss_D.item():.4f} "
              f"Loss_G: {loss_G.item():.4f} ", end="\r")

      if i % 50 == 0:
        vutils.save_image(real_imgs, f"{opt.out_images}/real_samples.png", normalize=True)
        with torch.no_grad():
          fake = netG(fixed_noise).detach().cpu()
        vutils.save_image(fake, f"{opt.out_images}/fake_samples_epoch_{epoch + 1}.png", normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), opt.netG)
    torch.save(netD.state_dict(), opt.netD)
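
The helper `calculate_gradient_penatly` called above is not shown in this example. A minimal sketch of what such a WGAN-GP penalty typically computes (keeping the example's spelling of the name so it matches the call site; `netD`, `real_samples`, and `fake_samples` follow the arguments used above):

import torch


def calculate_gradient_penatly(netD, real_samples, fake_samples):
  # Interpolate randomly between real and fake samples.
  alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
  interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
  d_interpolates = netD(interpolates)
  # Gradient of the critic output w.r.t. the interpolates.
  gradients = torch.autograd.grad(outputs=d_interpolates,
                                  inputs=interpolates,
                                  grad_outputs=torch.ones_like(d_interpolates),
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]
  gradients = gradients.view(gradients.size(0), -1)
  # Penalize deviation of the per-sample gradient norm from 1.
  return ((gradients.norm(2, dim=1) - 1) ** 2).mean()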
Ejemplo n.º 21
class ExperimentBuilder(nn.Module):
    def __init__(self,
                 network_model,
                 experiment_name,
                 num_epochs,
                 train_data,
                 val_data,
                 test_data,
                 weight_decay_coefficient,
                 use_gpu,
                 continue_from_epoch=-1):
        """
        Initializes an ExperimentBuilder object. Such an object takes care of running training and evaluation of a deep net
        on a given dataset. It also takes care of saving per epoch models and automatically inferring the best val model
        to be used for evaluating the test set metrics.
        :param network_model: A pytorch nn.Module which implements a network architecture.
        :param experiment_name: The name of the experiment. This is used mainly for keeping track of the experiment and creating a directory structure that will be used to save logs, model parameters and other artifacts.
        :param num_epochs: Total number of epochs to run the experiment
        :param train_data: An object of the DataProvider type. Contains the training set.
        :param val_data: An object of the DataProvider type. Contains the val set.
        :param test_data: An object of the DataProvider type. Contains the test set.
        :param weight_decay_coefficient: A float indicating the weight decay to use with the adam optimizer.
        :param use_gpu: A boolean indicating whether to use a GPU or not.
        :param continue_from_epoch: An int indicating whether we'll start from scratch (-1) or whether we'll reload a previously saved model of epoch 'continue_from_epoch' and continue training from there; -2 attempts to resume from the 'latest' checkpoint.
        """
        super(ExperimentBuilder, self).__init__()

        self.experiment_name = experiment_name
        self.model = network_model
        self.device = torch.device('cpu')  # default; replaced below if a GPU is requested and available

        if torch.cuda.device_count() > 1 and use_gpu:
            self.device = torch.cuda.current_device()
            self.model.to(self.device)
            self.model = nn.DataParallel(module=self.model)
            print('Use Multi GPU', self.device)
        elif torch.cuda.device_count() == 1 and use_gpu:
            self.device = torch.cuda.current_device()
            self.model.to(
                self.device)  # sends the model from the cpu to the gpu
            print('Use GPU', self.device)
        else:
            print("use CPU")
            self.device = torch.device('cpu')  # sets the device to be CPU
            print(self.device)

        # store the data providers
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.optimizer = Adam(self.parameters(),
                              amsgrad=False,
                              weight_decay=weight_decay_coefficient)

        print('System learnable parameters')
        num_conv_layers = 0
        num_linear_layers = 0
        total_num_parameters = 0
        for name, value in self.named_parameters():
            print(name, value.shape)
            if all(item in name for item in ['conv', 'weight']):
                num_conv_layers += 1
            if all(item in name for item in ['linear', 'weight']):
                num_linear_layers += 1
            total_num_parameters += np.prod(value.shape)

        print('Total number of parameters', total_num_parameters)
        print('Total number of conv layers', num_conv_layers)
        print('Total number of linear layers', num_linear_layers)

        # Generate the directory names
        self.experiment_folder = os.path.abspath(experiment_name)
        self.experiment_logs = os.path.abspath(
            os.path.join(self.experiment_folder, "result_outputs"))
        self.experiment_saved_models = os.path.abspath(
            os.path.join(self.experiment_folder, "saved_models"))
        print(self.experiment_folder, self.experiment_logs)
        # Set best models to be at 0 since we are just starting
        self.best_val_model_idx = 0
        self.best_val_model_acc = 0.

        if not os.path.exists(self.experiment_folder
                              ):  # If experiment directory does not exist
            os.mkdir(self.experiment_folder)  # create the experiment directory

        if not os.path.exists(self.experiment_logs):
            os.mkdir(
                self.experiment_logs)  # create the experiment log directory

        if not os.path.exists(self.experiment_saved_models):
            os.mkdir(self.experiment_saved_models
                     )  # create the experiment saved models directory

        self.num_epochs = num_epochs
        self.criterion = nn.CrossEntropyLoss().to(
            self.device)  # send the loss computation to the GPU
        if continue_from_epoch == -2:
            try:
                self.best_val_model_idx, self.best_val_model_acc, self.state = self.load_model(
                    model_save_dir=self.experiment_saved_models,
                    model_save_name="train_model",
                    model_idx='latest'
                )  # reload existing model from epoch and return best val model index
                # and the best val acc of that model
                self.starting_epoch = self.state['current_epoch_idx']
            except Exception:
                print(
                    "Model objects cannot be found, initializing a new model and starting from scratch"
                )
                self.starting_epoch = 0
                self.state = dict()

        elif continue_from_epoch != -1:  # if continue from epoch is not -1 then
            self.best_val_model_idx, self.best_val_model_acc, self.state = self.load_model(
                model_save_dir=self.experiment_saved_models,
                model_save_name="train_model",
                model_idx=continue_from_epoch
            )  # reload existing model from epoch and return best val model index
            # and the best val acc of that model
            self.starting_epoch = self.state['current_epoch_idx']
        else:
            self.starting_epoch = 0
            self.state = dict()

    def get_num_parameters(self):
        total_num_params = 0
        for param in self.parameters():
            total_num_params += np.prod(param.shape)

        return total_num_params

    def run_train_iter(self, x, y):
        """
        Receives the inputs and targets for the model and runs a training iteration. Returns loss and accuracy metrics.
        :param x: The inputs to the model. A numpy array of shape batch_size, channels, height, width
        :param y: The targets for the model. A numpy array of shape batch_size, num_classes
        :return: the loss and accuracy for this batch
        """
        self.train(
        )  # sets model to training mode (in case batch normalization or other methods have different procedures for training and evaluation)

        if len(y.shape) > 1:
            y = np.argmax(
                y, axis=1
            )  # convert one hot encoded labels to single integer labels

        #print(type(x))

        if type(x) is np.ndarray:
            x, y = torch.Tensor(x).float().to(
                device=self.device), torch.Tensor(y).long().to(
                    device=self.device)  # send data to device as torch tensors

        x = x.to(self.device)
        y = y.to(self.device)

        out = self.model.forward(x)  # forward the data in the model
        loss = F.cross_entropy(input=out, target=y)  # compute loss

        self.optimizer.zero_grad(
        )  # set all weight grads from previous training iters to 0
        loss.backward(
        )  # backpropagate to compute gradients for current iter loss

        self.optimizer.step()  # update network parameters
        _, predicted = torch.max(out.data, 1)  # get argmax of predictions
        accuracy = np.mean(list(predicted.eq(
            y.data).cpu()))  # compute accuracy
        return loss.data.detach().cpu().numpy(), accuracy

    def run_evaluation_iter(self, x, y):
        """
        Receives the inputs and targets for the model and runs an evaluation iteration. Returns loss and accuracy metrics.
        :param x: The inputs to the model. A numpy array of shape batch_size, channels, height, width
        :param y: The targets for the model. A numpy array of shape batch_size, num_classes
        :return: the loss and accuracy for this batch
        """
        self.eval()  # sets the system to validation mode
        if len(y.shape) > 1:
            y = np.argmax(
                y, axis=1
            )  # convert one hot encoded labels to single integer labels
        if type(x) is np.ndarray:
            x, y = torch.Tensor(x).float(
            ).to(device=self.device), torch.Tensor(y).long().to(
                device=self.device
            )  # convert data to pytorch tensors and send to the computation device

        x = x.to(self.device)
        y = y.to(self.device)
        out = self.model.forward(x)  # forward the data in the model
        loss = F.cross_entropy(out, y)  # compute loss
        _, predicted = torch.max(out.data, 1)  # get argmax of predictions
        accuracy = np.mean(list(predicted.eq(
            y.data).cpu()))  # compute accuracy
        return loss.data.detach().cpu().numpy(), accuracy

    def run_inference_iter(self, x):
        """
        Receives the inputs for the model and runs an inference iteration. Returns logits, the softmax probability distribution, and the argmax predictions.
        :param x: The inputs to the model. A numpy array of shape batch_size, channels, height, width
        :return: the logits, probability distribution, and argmax predictions for this batch
        """
        self.eval()  # sets the system to validation mode

        if type(x) is np.ndarray:
            x = torch.Tensor(x).float().to(device=self.device)
            # convert data to pytorch tensors and send to the computation device

        x = x.to(self.device)

        logits_out = self.model.forward(
            x)  # forward the data in the model and return logits
        probability_distribution_out = F.softmax(
            logits_out, dim=1)  # softmax turns logits into a probability distribution

        _, argmax_out = torch.max(logits_out.data,
                                  1)  # get argmax of predictions

        return logits_out, probability_distribution_out, argmax_out

    def save_model(self, model_save_dir, model_save_name, model_idx, state):
        """
        Save the network parameter state and current best val epoch idx and best val accuracy.
        :param model_save_dir: The directory to store the state at.
        :param model_save_name: Name to use to save the model, without the epoch index.
        :param model_idx: The index to save the model with.
        :param state: The dictionary containing the system state.

        """
        state['network'] = self.state_dict(
        )  # save network parameter and other variables.
        torch.save(
            state,
            f=os.path.join(model_save_dir, "{}_{}".format(
                model_save_name,
                str(model_idx))))  # save state at prespecified filepath

    def run_training_epoch(self, current_epoch_losses):
        with tqdm.tqdm(total=len(self.train_data), file=sys.stdout
                       ) as pbar_train:  # create a progress bar for training
            for idx, (x, y) in enumerate(self.train_data):  # get data batches
                loss, accuracy = self.run_train_iter(
                    x=x, y=y)  # take a training iter step
                current_epoch_losses["train_loss"].append(
                    loss)  # add current iter loss to the train loss list
                current_epoch_losses["train_acc"].append(
                    accuracy)  # add current iter acc to the train acc list
                pbar_train.update(1)
                pbar_train.set_description(
                    "loss: {:.4f}, accuracy: {:.4f}".format(loss, accuracy))

        return current_epoch_losses

    def run_validation_epoch(self, current_epoch_losses):

        with tqdm.tqdm(total=len(self.val_data), file=sys.stdout
                       ) as pbar_val:  # create a progress bar for validation
            for x, y in self.val_data:  # get data batches
                loss, accuracy = self.run_evaluation_iter(
                    x=x, y=y)  # run a validation iter
                current_epoch_losses["val_loss"].append(
                    loss)  # add current iter loss to val loss list.
                current_epoch_losses["val_acc"].append(
                    accuracy)  # add current iter acc to val acc lst.
                pbar_val.update(1)  # add 1 step to the progress bar
                pbar_val.set_description(
                    "loss: {:.4f}, accuracy: {:.4f}".format(loss, accuracy))

        return current_epoch_losses

    def run_testing_epoch(self, current_epoch_losses):

        with tqdm.tqdm(total=len(self.test_data),
                       file=sys.stdout) as pbar_test:  # init a progress bar
            for x, y in self.test_data:  # sample batch
                loss, accuracy = self.run_evaluation_iter(
                    x=x, y=y
                )  # compute loss and accuracy by running an evaluation step
                current_epoch_losses["test_loss"].append(
                    loss)  # save test loss
                current_epoch_losses["test_acc"].append(
                    accuracy)  # save test accuracy
                pbar_test.update(1)  # update progress bar status
                pbar_test.set_description(
                    "loss: {:.4f}, accuracy: {:.4f}".format(
                        loss, accuracy))  # update progress bar string output
        return current_epoch_losses

    def load_model(self, model_save_dir, model_save_name, model_idx):
        """
        Load the network parameter state and the best val model idx and best val acc, to be compared with future val accuracies in order to choose the best val model.
        :param model_save_dir: The directory the state is stored at.
        :param model_save_name: Name used to save the model, without the epoch index.
        :param model_idx: The index of the model to load.
        :return: best val idx and best val model acc; it also loads the network state into the system state without returning it.
        """
        state = torch.load(f=os.path.join(
            model_save_dir, "{}_{}".format(model_save_name, str(model_idx))))
        self.load_state_dict(state_dict=state['network'])
        return state['best_val_model_idx'], state['best_val_model_acc'], state

    def run_experiment(self):
        """
        Runs experiment train and evaluation iterations, saving the model and best val model and val model accuracy after each epoch
        :return: The summary current_epoch_losses from starting epoch to total_epochs.
        """
        total_losses = {
            "train_acc": [],
            "train_loss": [],
            "val_acc": [],
            "val_loss": [],
            "curr_epoch": []
        }  # initialize a dict to keep the per-epoch metrics
        for i, epoch_idx in enumerate(
                range(self.starting_epoch, self.num_epochs)):
            epoch_start_time = time.time()
            current_epoch_losses = {
                "train_acc": [],
                "train_loss": [],
                "val_acc": [],
                "val_loss": []
            }

            current_epoch_losses = self.run_training_epoch(
                current_epoch_losses)
            current_epoch_losses = self.run_validation_epoch(
                current_epoch_losses)

            val_mean_accuracy = np.mean(current_epoch_losses['val_acc'])

            if val_mean_accuracy > self.best_val_model_acc:  # if current epoch's mean val acc is greater than the saved best val acc then
                self.best_val_model_acc = val_mean_accuracy  # set the best val model acc to be current epoch's val accuracy
                self.best_val_model_idx = epoch_idx  # set the experiment-wise best val idx to be the current epoch's idx

            for key, value in current_epoch_losses.items():
                total_losses[key].append(np.mean(value))
                # get mean of all metrics of current epoch metrics dict,
                # to get them ready for storage and output on the terminal.

            total_losses['curr_epoch'].append(epoch_idx)
            save_statistics(experiment_log_dir=self.experiment_logs,
                            filename='summary.csv',
                            stats_dict=total_losses,
                            current_epoch=i,
                            continue_from_mode=True if
                            (self.starting_epoch != 0 or i > 0) else
                            False)  # save statistics to stats file.

            # load_statistics(experiment_log_dir=self.experiment_logs, filename='summary.csv') # How to load a csv file if you need to

            out_string = "_".join([
                "{}_{:.4f}".format(key, np.mean(value))
                for key, value in current_epoch_losses.items()
            ])
            # create a string to use to report our epoch metrics
            epoch_elapsed_time = time.time(
            ) - epoch_start_time  # calculate time taken for epoch
            epoch_elapsed_time = "{:.4f}".format(epoch_elapsed_time)
            print("Epoch {}:".format(epoch_idx), out_string, "epoch time",
                  epoch_elapsed_time, "seconds")
            self.state['current_epoch_idx'] = epoch_idx
            self.state['best_val_model_acc'] = self.best_val_model_acc
            self.state['best_val_model_idx'] = self.best_val_model_idx
            self.save_model(
                model_save_dir=self.experiment_saved_models,
                # save model and best val idx and best val acc, using the model dir, model name and model idx
                model_save_name="train_model",
                model_idx=epoch_idx,
                state=self.state)
            self.save_model(
                model_save_dir=self.experiment_saved_models,
                # save model and best val idx and best val acc, using the model dir, model name and model idx
                model_save_name="train_model",
                model_idx='latest',
                state=self.state)

        print("Generating test set evaluation metrics")
        self.load_model(
            model_save_dir=self.experiment_saved_models,
            model_idx=self.best_val_model_idx,
            # load best validation model
            model_save_name="train_model")
        current_epoch_losses = {
            "test_acc": [],
            "test_loss": []
        }  # initialize a statistics dict

        current_epoch_losses = self.run_testing_epoch(
            current_epoch_losses=current_epoch_losses)

        test_losses = {
            key: [np.mean(value)]
            for key, value in current_epoch_losses.items()
        }  # save test set metrics in dict format

        save_statistics(
            experiment_log_dir=self.experiment_logs,
            filename='test_summary.csv',
            # save test set metrics on disk in .csv format
            stats_dict=test_losses,
            current_epoch=0,
            continue_from_mode=False)

        return total_losses, test_losses
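
A minimal usage sketch for this ExperimentBuilder (the network and data-provider names below are hypothetical placeholders, not part of the example):

experiment = ExperimentBuilder(network_model=my_convnet,      # any nn.Module
                               experiment_name="exp_1",
                               num_epochs=50,
                               train_data=train_provider,     # iterables yielding (x, y) batches
                               val_data=val_provider,
                               test_data=test_provider,
                               weight_decay_coefficient=1e-4,
                               use_gpu=True,
                               continue_from_epoch=-1)        # -1 starts from scratch
total_losses, test_losses = experiment.run_experiment()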
Ejemplo n.º 22
class DRCNetworks(ConvolutionalNeuralNetworks):
    '''
    Deep Reconstruction-Classification Networks (DRCN or DRCNetworks).

    Deep Reconstruction-Classification Network (DRCN or DRCNetworks) is a convolutional network 
    that jointly learns two tasks: 
    
    1. supervised source label prediction.
    2. unsupervised target data reconstruction. 

    Ideally, a discriminative representation should model both the label and 
    the structure of the data. Based on that intuition, Ghifary, M., et al.(2016) hypothesize 
    that a domain-adaptive representation should satisfy two criteria:
    
    1. classify the labeled source domain data well.
    2. reconstruct the unlabeled target domain data well, which can be viewed as an approximation of the ideal discriminative representation.

    The encoding parameters of the DRCN are shared across both tasks, 
    while the decoding parameters are separated. The aim is that the learned label 
    prediction function can perform well on classifying images in the target domain; 
    the data reconstruction can thus be viewed as an auxiliary task that supports 
    the adaptation of the label prediction.

    References:
        - Ghifary, M., Kleijn, W. B., Zhang, M., Balduzzi, D., & Li, W. (2016, October). Deep reconstruction-classification networks for unsupervised domain adaptation. In European Conference on Computer Vision (pp. 597-613). Springer, Cham.
    '''
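    # A minimal sketch of the DRCN objective described above (the tensors and
    # the weight `lam` here are hypothetical, not attributes of this class):
    # the joint loss combines a supervised classification loss on labeled
    # source data with an unsupervised reconstruction loss on target data,
    #
    #   loss = lam * cross_entropy(source_logits, source_labels) \
    #          + (1.0 - lam) * mse_loss(decoded_target, target_images)
    #
    # with 0 <= lam <= 1 trading off the two tasks; `self.drcn_loss` plays
    # this role in `compute_loss()` below.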

    # `bool` that means initialization in this class will be deferred or not.
    __init_deferred_flag = False

    # `list` of losses.
    __loss_list = []

    # `list` of accuracies.
    __acc_list = []

    def __init__(
        self,
        convolutional_auto_encoder,
        drcn_loss,
        initializer_f=None,
        optimizer_f=None,
        auto_encoder_optimizer_f=None,
        learning_rate=1e-05,
        hidden_units_list=[],
        output_nn=None,
        hidden_dropout_rate_list=[],
        hidden_activation_list=[],
        hidden_batch_norm_list=[],
        ctx="cpu",
        regularizatable_data_list=[],
        scale=1.0,
        tied_weights_flag=True,
        est=1e-08,
        wd=0.0,
        not_init_flag=False,
    ):
        '''
        Init.

        Args:
            convolutional_auto_encoder:     is-a `ConvolutionalAutoEncoder`.
            drcn_loss:                      is-a `DRCNLoss`.
            initializer_f:                  `function` that initializes the weights of model parameters. If `None`, they are drawn from the Xavier distribution.
            optimizer_f:                    `function` that builds an optimizer from an iterable of parameters. If `None`, `Adam` is used.
            auto_encoder_optimizer_f:       `function` that builds the auto-encoder's optimizers. If `None`, `optimizer_f` or `Adam` is used.
            learning_rate:                  `float` of learning rate.
            hidden_units_list:              `list` of convolution modules in hidden layers.
            output_nn:                      is-a `NeuralNetworks` as output layers.
                                            If `None`, the last layer in `hidden_units_list` will be considered as an output layer.
            hidden_dropout_rate_list:       `list` of `float` of dropout rate in hidden layers.
            hidden_activation_list:         `list` of activations in hidden layers.
            hidden_batch_norm_list:         `list` of batch normalization modules in hidden layers.
            ctx:                            Device context, e.g. `"cpu"` or `"cuda"`.
            regularizatable_data_list:      `list` of `RegularizatableData`.
            scale:                          `float` of scaling factor for initial parameters.
            tied_weights_flag:              `bool` of flag to tie weights or not.
            est:                            `float` of the early stopping threshold.
                                            If it is not `None`, the learning will be stopped 
                                            when (|loss - previous_loss| < est).
            wd:                             `float` of weight decay.
            not_init_flag:                  `bool` of flag that means initialization of the optimizers will be skipped or not.

        '''
        if isinstance(convolutional_auto_encoder,
                      ConvolutionalAutoEncoder) is False:
            raise TypeError(
                "The type of `convolutional_auto_encoder` must be `ConvolutionalAutoEncoder`."
            )

        if len(hidden_units_list) != len(hidden_activation_list):
            raise ValueError(
                "The length of `hidden_units_list` and `hidden_activation_list` must be equivalent."
            )

        if len(hidden_dropout_rate_list) != len(hidden_units_list):
            raise ValueError(
                "The length of `hidden_dropout_rate_list` and `hidden_units_list` must be equivalent."
            )

        if isinstance(drcn_loss, DRCNLoss) is False:
            raise TypeError("The type of `drcn_loss` must be `DRCNLoss`.")

        logger = getLogger("accelbrainbase")
        self.__logger = logger
        init_deferred_flag = self.init_deferred_flag
        self.init_deferred_flag = True
        super().__init__(
            computable_loss=drcn_loss,
            initializer_f=initializer_f,
            learning_rate=learning_rate,
            hidden_units_list=hidden_units_list,
            output_nn=None,
            hidden_dropout_rate_list=hidden_dropout_rate_list,
            optimizer_f=optimizer_f,
            hidden_activation_list=hidden_activation_list,
            hidden_batch_norm_list=hidden_batch_norm_list,
            ctx=ctx,
            regularizatable_data_list=regularizatable_data_list,
            scale=scale,
        )
        self.init_deferred_flag = init_deferred_flag
        self.convolutional_auto_encoder = convolutional_auto_encoder
        self.__tied_weights_flag = tied_weights_flag

        self.output_nn = output_nn

        for v in regularizatable_data_list:
            if isinstance(v, RegularizatableData) is False:
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `RegularizatableData`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.drcn_loss = drcn_loss

        self.__learning_rate = learning_rate

        self.__ctx = ctx

        self.__loss_list = []
        self.__acc_list = []
        self.__target_domain_arr = None

        self.__tied_weights_flag = tied_weights_flag

        self.__est = est

        self.__not_init_flag = not_init_flag

        for i in range(len(hidden_units_list)):
            if initializer_f is None:
                hidden_units_list[i].weight = torch.nn.init.xavier_normal_(
                    hidden_units_list[i].weight, gain=1.0)
            else:
                hidden_units_list[i].weight = initializer_f(
                    hidden_units_list[i].weight)

        if self.init_deferred_flag is False:
            if self.__not_init_flag is False:
                if auto_encoder_optimizer_f is not None:
                    self.convolutional_auto_encoder.encoder_optimizer = auto_encoder_optimizer_f(
                        self.convolutional_auto_encoder.encoder.parameters(), )
                    self.convolutional_auto_encoder.decoder_optimizer = auto_encoder_optimizer_f(
                        self.convolutional_auto_encoder.decoder.parameters(), )
                elif optimizer_f is not None:
                    self.convolutional_auto_encoder.encoder_optimizer = optimizer_f(
                        self.convolutional_auto_encoder.encoder.parameters(), )
                    self.convolutional_auto_encoder.decoder_optimizer = optimizer_f(
                        self.convolutional_auto_encoder.decoder.parameters(), )
                else:
                    self.convolutional_auto_encoder.encoder_optimizer = Adam(
                        self.convolutional_auto_encoder.encoder.parameters(),
                        lr=self.__learning_rate)
                    self.convolutional_auto_encoder.decoder_optimizer = Adam(
                        self.convolutional_auto_encoder.decoder.parameters(),
                        lr=self.__learning_rate)

                if optimizer_f is not None and self.output_nn is not None:
                    self.optimizer = optimizer_f(self.output_nn.parameters(), )
                elif optimizer_f is not None:
                    self.optimizer = optimizer_f(self.parameters(), )
                else:
                    self.optimizer = Adam(
                        self.parameters(),
                        lr=self.__learning_rate,
                    )

        self.flatten = nn.Flatten()

    def parameters(self):
        '''
        Build optimizer parameter groups from the auto-encoder's parameters,
        this network's own parameters, and, if present, the output network's.
        '''
        params_dict_list = [{
            "params":
            self.convolutional_auto_encoder.parameters(),
        }, {
            # Use super().parameters() here; calling self.parameters()
            # would recurse into this method forever.
            "params": super().parameters(),
        }]
        if self.output_nn is not None:
            params_dict_list.append({"params": self.output_nn.parameters()})
        return params_dict_list
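    # Note: the param-group dicts returned above are in the format torch
    # optimizers accept directly; the Adam fallback in __init__ consumes
    # them via Adam(self.parameters(), lr=...).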

    def learn(self, iteratable_data):
        '''
        Learn the observed data points with domain adaptation.

        Args:
            iteratable_data:     is-a `DRCNIterator`.

        '''
        if isinstance(iteratable_data, DRCNIterator) is False:
            raise TypeError(
                "The type of `iteratable_data` must be `DRCNIterator`.")

        self.__loss_list = []
        self.__acc_list = []
        learning_rate = self.__learning_rate
        self.__previous_loss = None
        est_flag = False
        try:
            epoch = 0
            iter_n = 0
            for batch_observed_arr, batch_target_arr, test_batch_observed_arr, test_batch_target_arr, target_domain_arr in iteratable_data.generate_learned_samples(
            ):
                self.epoch = epoch
                self.batch_size = batch_observed_arr.shape[0]
                self.convolutional_auto_encoder.batch_size = self.batch_size

                self.convolutional_auto_encoder.encoder_optimizer.zero_grad()
                self.convolutional_auto_encoder.decoder_optimizer.zero_grad()
                self.optimizer.zero_grad()

                # rank-3
                decoded_arr = self.inference_auto_encoder(target_domain_arr)
                _, prob_arr = self.inference(batch_observed_arr)
                loss, train_classification_loss, train_reconstruction_loss = self.compute_loss(
                    decoded_arr, prob_arr, target_domain_arr, batch_target_arr)

                if self.__previous_loss is not None:
                    loss_diff = loss - self.__previous_loss
                    if self.__est is not None and torch.abs(
                            loss_diff) < self.__est:
                        est_flag = True

                if est_flag is False:
                    loss.backward()
                    self.convolutional_auto_encoder.encoder_optimizer.step()
                    self.convolutional_auto_encoder.decoder_optimizer.step()
                    self.optimizer.step()
                    self.regularize()

                    self.__previous_loss = loss

                    if (iter_n + 1) % int(iteratable_data.iter_n /
                                          iteratable_data.epochs) == 0:
                        with torch.inference_mode():
                            # rank-3
                            test_decoded_arr = self.inference_auto_encoder(
                                target_domain_arr)
                            _, test_prob_arr = self.inference(
                                test_batch_observed_arr)
                            test_loss, test_classification_loss, test_reconstruction_loss = self.compute_loss(
                                test_decoded_arr, test_prob_arr,
                                target_domain_arr, test_batch_target_arr)

                        _loss = loss.to('cpu').detach().numpy().copy()
                        _train_classification_loss = train_classification_loss.to(
                            'cpu').detach().numpy().copy()
                        _train_reconstruction_loss = train_reconstruction_loss.to(
                            'cpu').detach().numpy().copy()
                        _test_loss = test_loss.to(
                            'cpu').detach().numpy().copy()
                        _test_classification_loss = test_classification_loss.to(
                            'cpu').detach().numpy().copy()
                        _test_reconstruction_loss = test_reconstruction_loss.to(
                            'cpu').detach().numpy().copy()

                        self.__logger.debug("Epochs: " + str(epoch + 1) +
                                            " Train total loss: " +
                                            str(_loss) + " Test total loss: " +
                                            str(_test_loss))
                        self.__logger.debug("Train classification loss: " +
                                            str(_train_classification_loss) +
                                            " Test classification loss: " +
                                            str(_test_classification_loss))
                        self.__logger.debug("Train reconstruction loss: " +
                                            str(_train_reconstruction_loss) +
                                            " Test reconstruction loss: " +
                                            str(_test_reconstruction_loss))

                    if self.compute_acc_flag is True:
                        if (iter_n + 1) % int(iteratable_data.iter_n /
                                              iteratable_data.epochs) == 0:
                            acc, inferenced_label_arr, answer_label_arr = self.compute_acc(
                                prob_arr, batch_target_arr)
                            test_acc, test_inferenced_label_arr, test_answer_label_arr = self.compute_acc(
                                test_prob_arr, test_batch_target_arr)
                            if (epoch + 1) % 100 == 0 or epoch < 100:
                                # accuracies were already computed above

                                self.__logger.debug("-" * 100)
                                self.__logger.debug("Train accuracy: " +
                                                    str(acc) +
                                                    " Test accuracy: " +
                                                    str(test_acc))
                                self.__logger.debug(
                                    "Train inferenced label:")
                                self.__logger.debug(
                                    inferenced_label_arr.to(
                                        'cpu').detach().numpy())
                                self.__logger.debug(
                                    "Train answer label:")
                                self.__logger.debug(
                                    answer_label_arr.to(
                                        'cpu').detach().numpy())

                                self.__logger.debug(
                                    "Test inferenced label:")
                                self.__logger.debug(
                                    test_inferenced_label_arr.to(
                                        'cpu').detach().numpy())
                                self.__logger.debug(
                                    "Test answer label:")
                                self.__logger.debug(
                                    test_answer_label_arr.to(
                                        'cpu').detach().numpy())
                                self.__logger.debug("-" * 100)

                                if (
                                        test_answer_label_arr[0]
                                        == test_answer_label_arr
                                ).to('cpu').detach().numpy().astype(int).sum(
                                ) != test_answer_label_arr.shape[0]:
                                    if (
                                            test_inferenced_label_arr[0] ==
                                            test_inferenced_label_arr
                                    ).to('cpu').detach().numpy(
                                    ).astype(int).sum(
                                    ) == test_inferenced_label_arr.shape[0]:
                                        self.__logger.debug(
                                            "It may be overfitting.")

                    if (iter_n + 1) % int(iteratable_data.iter_n /
                                          iteratable_data.epochs) == 0:
                        self.__loss_list.append(
                            (_loss, _test_loss, _train_classification_loss,
                             _test_classification_loss,
                             _train_reconstruction_loss,
                             _test_reconstruction_loss))
                        if self.compute_acc_flag is True:
                            self.__acc_list.append((acc, test_acc))

                    if (iter_n + 1) % int(iteratable_data.iter_n /
                                          iteratable_data.epochs) == 0:
                        epoch += 1
                    iter_n += 1

                else:
                    self.__logger.debug("Early stopping.")
                    break

        except KeyboardInterrupt:
            self.__logger.debug("Interrupt.")

        self.__logger.debug("end. ")

    def forward(self, x):
        '''
        Forward propagation.

        Args:
            x:      `tensor` of observed data points.
        
        Returns:
            Tuple data.
                - `tensor` of reconstructed feature points.
                - `tensor` of inferenced label.
        '''
        decoded_arr = self.convolutional_auto_encoder(x)
        self.feature_points_arr = self.convolutional_auto_encoder.feature_points_arr

        if self.output_nn is not None:
            prob_arr = self.output_nn(self.feature_points_arr)
        else:
            prob_arr = self.feature_points_arr

        return decoded_arr, prob_arr

    def inference_auto_encoder(self, x):
        '''
        Forward propagation (Auto-Encoder only).

        Args:
            x:      `tensor` of observed data points.
        
        Returns:
            `tensor` of reconstructed feature points.
        '''
        return self.convolutional_auto_encoder(x)

    def regularize(self):
        '''
        Regularization.
        '''
        self.convolutional_auto_encoder.regularize()
        super().regularize()

    def compute_loss(self, decoded_arr, prob_arr, batch_observed_arr,
                     batch_target_arr):
        '''
        Compute loss.

        Args:
            decoded_arr:            `tensor` of decoded feature points.
            prob_arr:               `tensor` of predicted labels data.
            batch_observed_arr:     `tensor` of observed data points.
            batch_target_arr:       `tensor` of label data.

        Returns:
            loss.
        '''
        return self.drcn_loss(decoded_arr, prob_arr, batch_observed_arr,
                              batch_target_arr)

    def set_readonly(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    def get_loss_arr(self):
        ''' getter for losses. '''
        return np.array(self.__loss_list)

    loss_arr = property(get_loss_arr, set_readonly)

    def get_acc_list(self):
        ''' getter for accuracies. '''
        return np.array(self.__acc_list)

    acc_arr = property(get_acc_list, set_readonly)
Ejemplo n.º 23
class Trainer:
    def __init__(self, config, data_loader):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.num_epoch = config.num_epoch
        self.epoch = config.epoch
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.checkpoint_dir = config.checkpoint_dir
        self.batch_size = config.batch_size
        self.sample_dir = config.sample_dir
        self.nf = config.nf
        self.scale_factor = config.scale_factor

        if config.is_perceptual_oriented:
            self.lr = config.p_lr
            self.content_loss_factor = config.p_content_loss_factor
            self.perceptual_loss_factor = config.p_perceptual_loss_factor
            self.adversarial_loss_factor = config.p_adversarial_loss_factor
            self.decay_iter = config.p_decay_iter
        else:
            self.lr = config.g_lr
            self.content_loss_factor = config.g_content_loss_factor
            self.perceptual_loss_factor = config.g_perceptual_loss_factor
            self.adversarial_loss_factor = config.g_adversarial_loss_factor
            self.decay_iter = config.g_decay_iter

        self.build_model()
        self.optimizer_generator = Adam(self.generator.parameters(), lr=self.lr, betas=(config.b1, config.b2),
                                        weight_decay=config.weight_decay)
        self.optimizer_discriminator = Adam(self.discriminator.parameters(), lr=self.lr, betas=(config.b1, config.b2),
                                            weight_decay=config.weight_decay)

        self.lr_scheduler_generator = torch.optim.lr_scheduler.MultiStepLR(self.optimizer_generator, self.decay_iter)
        self.lr_scheduler_discriminator = torch.optim.lr_scheduler.MultiStepLR(self.optimizer_discriminator, self.decay_iter)

    def train(self):
        total_step = len(self.data_loader)
        adversarial_criterion = nn.BCEWithLogitsLoss().to(self.device)
        content_criterion = nn.L1Loss().to(self.device)
        perception_criterion = PerceptualLoss().to(self.device)
        self.generator.train()
        self.discriminator.train()

        for epoch in range(self.epoch, self.num_epoch):
            if not os.path.exists(os.path.join(self.sample_dir, str(epoch))):
                os.makedirs(os.path.join(self.sample_dir, str(epoch)))

            for step, image in enumerate(self.data_loader):
                low_resolution = image['lr'].to(self.device)
                high_resolution = image['hr'].to(self.device)

                real_labels = torch.ones((high_resolution.size(0), 1)).to(self.device)
                fake_labels = torch.zeros((high_resolution.size(0), 1)).to(self.device)

                ##########################
                #   training generator   #
                ##########################
                self.optimizer_generator.zero_grad()
                fake_high_resolution = self.generator(low_resolution)

                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution)
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(discriminator_rf, fake_labels)
                adversarial_loss_fr = adversarial_criterion(discriminator_fr, real_labels)
                adversarial_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2

                perceptual_loss = perception_criterion(high_resolution, fake_high_resolution)
                content_loss = content_criterion(fake_high_resolution, high_resolution)

                generator_loss = adversarial_loss * self.adversarial_loss_factor + \
                                 perceptual_loss * self.perceptual_loss_factor + \
                                 content_loss * self.content_loss_factor

                generator_loss.backward()
                self.optimizer_generator.step()

                ##########################
                # training discriminator #
                ##########################

                self.optimizer_discriminator.zero_grad()

                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution.detach())
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(discriminator_rf, real_labels)
                adversarial_loss_fr = adversarial_criterion(discriminator_fr, fake_labels)
                discriminator_loss = (adversarial_loss_fr + adversarial_loss_rf) / 2

                discriminator_loss.backward()
                self.optimizer_discriminator.step()

                self.lr_scheduler_generator.step()
                self.lr_scheduler_discriminator.step()
                if step % 1000 == 0:
                    print(f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                          f"[D loss {discriminator_loss.item():.4f}] [G loss {generator_loss.item():.4f}] "
                          f"[adversarial loss {adversarial_loss.item() * self.adversarial_loss_factor:.4f}]"
                          f"[perceptual loss {perceptual_loss.item() * self.perceptual_loss_factor:.4f}]"
                          f"[content loss {content_loss.item() * self.content_loss_factor:.4f}]"
                          f"")
                    if step % 5000 == 0:
                        result = torch.cat((high_resolution, fake_high_resolution), 2)
                        save_image(result, os.path.join(self.sample_dir, str(epoch), f"SR_{step}.png"))

            torch.save(self.generator.state_dict(), os.path.join(self.checkpoint_dir, f"generator_{epoch}.pth"))
            torch.save(self.discriminator.state_dict(), os.path.join(self.checkpoint_dir, f"discriminator_{epoch}.pth"))

    def build_model(self):
        self.generator = ESRGAN(3, 3, 64, scale_factor=self.scale_factor).to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.load_model()

    def load_model(self):
        print(f"[*] Load model from {self.checkpoint_dir}")
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        if not os.listdir(self.checkpoint_dir):
            print(f"[!] No checkpoint in {self.checkpoint_dir}")
            return

        generator = glob(os.path.join(self.checkpoint_dir, f'generator_{self.epoch - 1}.pth'))
        discriminator = glob(os.path.join(self.checkpoint_dir, f'discriminator_{self.epoch - 1}.pth'))

        if not generator or not discriminator:
            print(f"[!] No checkpoint in epoch {self.epoch - 1}")
            return

        self.generator.load_state_dict(torch.load(generator[0]))
        self.discriminator.load_state_dict(torch.load(discriminator[0]))
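
The adversarial terms in train() above follow the relativistic average GAN (RaGAN) formulation used by ESRGAN: each score is judged relative to the mean score of the opposite class. A minimal sketch, assuming score tensors shaped like the discriminator outputs above:

import torch.nn as nn

bce = nn.BCEWithLogitsLoss()


def ragan_discriminator_loss(score_real, score_fake, real_labels, fake_labels):
    # Real images should score above the average fake score,
    # and fake images below the average real score.
    loss_rf = bce(score_real - score_fake.mean(), real_labels)
    loss_fr = bce(score_fake - score_real.mean(), fake_labels)
    return (loss_rf + loss_fr) / 2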
Ejemplo n.º 24
class ChatSpaceTrainer:
    def __init__(
        self,
        config,
        model: ChatSpaceModel,
        vocab: Vocab,
        device: torch.device,
        train_corpus_path,
        eval_corpus_path=None,
        encoding="utf-8",
    ):
        self.config = config
        self.device = device
        self.model = model
        self.optimizer = Adam(self.model.parameters(),
                              lr=config["learning_rate"])
        self.criterion = nn.NLLLoss()
        self.vocab = vocab
        self.encoding = encoding

        self.train_corpus = DynamicCorpus(train_corpus_path,
                                          repeat=True,
                                          encoding=self.encoding)
        self.train_dataset = ChatSpaceDataset(config,
                                              self.train_corpus,
                                              self.vocab,
                                              with_random_space=True)

        if eval_corpus_path is not None:
            self.eval_corpus = DynamicCorpus(eval_corpus_path,
                                             encoding=self.encoding)
            self.eval_dataset = ChatSpaceDataset(self.config,
                                                 self.eval_corpus,
                                                 self.vocab,
                                                 with_random_space=True)

        self.global_epochs = 0
        self.global_steps = 0

    def eval(self, batch_size=64):
        self.model.eval()

        with torch.no_grad():
            eval_output = self.run_epoch(self.eval_dataset,
                                         batch_size=batch_size,
                                         is_train=False)

        self.model.train()
        return eval_output

    def train(self, epochs=10, batch_size=64):
        for epoch_id in range(epochs):
            self.run_epoch(
                self.train_dataset,
                batch_size=batch_size,
                epoch_id=epoch_id,
                is_train=True,
                log_freq=self.config["logging_step"],
            )
            self.save_checkpoint(
                f"outputs/checkpoints/checkpoint_ep{epoch_id}.cpt")
            self.save_model(f"outputs/models/chatspace_ep{epoch_id}.pt")
            self.save_model(f"outputs/jit_models/chatspace_ep{epoch_id}.pt",
                            as_jit=False)

    def run_epoch(self,
                  dataset,
                  batch_size=64,
                  epoch_id=0,
                  is_train=True,
                  log_freq=100):
        step_outputs, step_metrics, step_inputs = [], [], []
        collect_fn = (ChatSpaceDataset.train_collect_fn
                      if is_train else ChatSpaceDataset.eval_collect_fn)
        data_loader = DataLoader(dataset, batch_size, collate_fn=collect_fn)
        for step_num, batch in enumerate(data_loader):
            batch = {
                key: value.to(self.device)
                for key, value in batch.items()
            }
            output = self.step(step_num, batch)

            if is_train:
                self.update(output["loss"])

            if not is_train or step_num % log_freq == 0:
                batch = {
                    key: value.cpu().numpy()
                    for key, value in batch.items()
                }
                output = {
                    key: value.detach().cpu().numpy()
                    for key, value in output.items()
                }

                metric = self.step_metric(output["output"], batch,
                                          output["loss"])

                if is_train:
                    print(
                        f"EPOCH:{epoch_id}",
                        f"STEP:{step_num}/{len(data_loader)}",
                        [(key + ":" + "%.3f" % metric[key]) for key in metric],
                    )
                else:
                    step_outputs.append(output)
                    step_metrics.append(metric)
                    step_inputs.append(batch)

        if not is_train:
            return self.epoch_metric(step_inputs, step_outputs, step_metrics)

        if is_train:
            self.global_epochs += 1

    def epoch_metric(self, step_inputs, step_outputs, step_metrics):
        average_loss = np.mean([metric["loss"] for metric in step_metrics])

        epoch_inputs = [
            example for step_input in step_inputs
            for example in step_input["input"].tolist()
        ]
        epoch_outputs = [
            example for output in step_outputs
            for example in output["output"].argmax(axis=-1).tolist()
        ]
        epoch_labels = [
            example for step_input in step_inputs
            for example in step_input["label"].tolist()
        ]

        epoch_metric = calculated_metric(batch_input=epoch_inputs,
                                         batch_output=epoch_outputs,
                                         batch_label=epoch_labels)

        epoch_metric["loss"] = average_loss
        return epoch_metric

    def step_metric(self, output, batch, loss=None):
        metric = calculated_metric(
            batch_input=batch["input"].tolist(),
            batch_output=output.argmax(axis=-1).tolist(),
            batch_label=batch["label"].tolist(),
        )

        if loss is not None:
            metric["loss"] = loss
        return metric

    def step(self, step_num, batch, with_loss=True, is_train=True):
        output = self.model.forward(batch["input"], batch["length"])
        if is_train:
            self.global_steps += 1

        if not with_loss:
            return {"output": output}

        loss = self.criterion(output.transpose(1, 2), batch["label"])
        return {"loss": loss, "output": output}

    def update(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save_model(self, path, as_jit=False):
        self.optimizer.zero_grad()
        params = [{
            "param": param,
            "requires_grad": param.requires_grad
        } for param in self.model.parameters()]

        for param in params:
            param["param"].requires_grad = False

        with torch.no_grad():
            if not as_jit:
                torch.save(self.model.state_dict(), path)
            else:
                self.model.cpu().eval()

                sample_texts = ["오늘 너무 재밌지 않았어?", "너랑 하루종일 놀아서 기분이 좋았어!"]
                dataset = ChatSpaceDataset(self.config,
                                           sample_texts,
                                           self.vocab,
                                           with_random_space=False)
                data_loader = DataLoader(dataset,
                                         batch_size=2,
                                         collate_fn=dataset.eval_collect_fn)

                for batch in data_loader:
                    model_input = (batch["input"].detach(),
                                   batch["length"].detach())
                    traced_model = torch.jit.trace(self.model, model_input)
                    torch.jit.save(traced_model, path)
                    break

                self.model.to(self.device).train()

        print(f"Model Saved on {path}{' as_jit' if as_jit else ''}")

        for param in params:
            if param["requires_grad"]:
                param["param"].requires_grad = True

    def save_checkpoint(self, path):
        torch.save(
            {
                "epoch": self.global_epochs,
                "steps": self.global_steps,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            path,
        )

    def load_checkpoint(self, path):
        checkpoint = torch.load(path)
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.global_epochs = checkpoint["epoch"]
        self.global_steps = checkpoint["steps"]

    def load_model(self, model_path):
        self.model.load_state_dict(torch.load(model_path))
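
A model exported with save_model(path, as_jit=True) can later be reloaded
without the Python class. A minimal sketch, assuming a hypothetical archive
path and the same (input, length) tensor pair used for tracing:

import torch

traced = torch.jit.load("chatspace_traced.pth")  # hypothetical path
traced.eval()
with torch.no_grad():
    output = traced(input_ids, lengths)  # tensors shaped like the tracing batch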
Example no. 25
    theta_G = [G_PI, G_MU, G_A, G_D]
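    # theta_G presumably holds the MFA parameters (mixing weights, means,
    # factor loadings, diagonal noise); Adam optimizes these tensors directly
    # by minimizing the negative log-likelihood below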

    optimizer = Adam(theta_G, lr=LR)

    for i in range(NUM_EPOCHS):
        train_loss = 0
        count = 0

        for x in train_loader:
            optimizer.zero_grad()

            loss = -get_log_likelihood(x.cuda(), *theta_G)

            loss.backward()

            optimizer.step()

            train_loss += loss.cpu().detach().item()
            count += 1

        train_loss /= count

        # Calculate validation loss
        val_loss = 0
        count = 0

        with torch.no_grad():
            for x in val_loader:
                loss = -get_log_likelihood(x.cuda(), *theta_G)

                val_loss += loss.cpu().detach().item()
                count += 1

        val_loss /= count
Example no. 26
def main():
    # command-line argument parsing
    parser = ArgumentParser()
    parser.add_argument("-f", "--fold", default=0, type=int, help="enter fold")
    args = parser.parse_args()
    fold_index = args.fold

    # read the config file
    config = configparser.ConfigParser()
    config.sections()
    config.read('config.ini')

    # read the per-fold train/validation splits
    train_dataset = pd.read_csv(config["train_images"]["fold_directory"] +
                                "fold_train" + str(fold_index) + ".csv")
    val_dataset = pd.read_csv(config["train_images"]["fold_directory"] +
                              "fold_test" + str(fold_index) + ".csv")
    path = config["train_images"]["parts_dir"]
    create_black_image(path)

    train_dataset = generate_parts_df(train_dataset, path)
    val_dataset = generate_parts_df(val_dataset, path)

    (train_transform, parts_transform, trunk_transform, val_transform,
     val_parts_transform, val_trunk_transform) = create_transformation()

    target = train_dataset['encoded_labels'].values
    train_images_path = config["train_images"]["train_image_path"]
    part_images_path = config["train_images"]["parts_dir"]
    train_dataset = TigersDataset(train_dataset, train_images_path,
                                  part_images_path, train_transform,
                                  parts_transform, trunk_transform)
    val_dataset = TigersDataset(val_dataset, train_images_path,
                                part_images_path, val_transform,
                                val_parts_transform, val_trunk_transform)

    n_classes = 4
    n_samples = 4

    batch_size = n_classes * n_samples
    balanced_batch_sampler_train = BalancedBatchSampler(
        train_dataset, n_classes, n_samples)
    balanced_batch_sampler_val = BalancedBatchSampler(val_dataset, 8, 2)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_sampler=balanced_batch_sampler_train)
    validation_loader = torch.utils.data.DataLoader(
        val_dataset, batch_sampler=balanced_batch_sampler_val)

    model = ClassificationNet()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.0003)

    # scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = StepLR(optimizer, step_size=80, gamma=0.5)
    #   scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=25, after_scheduler=scheduler_steplr)
    scheduler_warmup = LearningRateWarmUP(optimizer=optimizer,
                                          warmup_iteration=0,
                                          target_lr=0.0003,
                                          after_scheduler=scheduler_steplr)

    # this zero gradient update is needed to avoid a warning message, issue #8.
    optimizer.zero_grad()
    optimizer.step()

    n_epochs = 1
    print_every = 10
    margin = 0.3

    valid_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
    val_loss = []
    train_loss = []
    train_acc = []
    val_acc = []
    total_step = len(train_loader)
    for epoch in range(1, n_epochs + 1):
        running_loss = 0.0
        # scheduler.step(epoch)
        correct = 0
        total = 0
        print(f'Epoch {epoch}\n')
        scheduler_warmup.step(epoch)
        for batch_idx, (full_image, part1, part2, part3, part4, part5, part6,
                        body, target_) in enumerate(train_loader):

            # move every input tensor and the target to the training device
            full_image, part1, part2, part3, part4, part5, part6, body, target_ = (
                t.to(device) for t in (full_image, part1, part2, part3, part4,
                                       part5, part6, body, target_))
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            #outputs = model(data_.cuda())
            output1, output2, output3, Zft, Zfl = model(
                full_image, part1, part2, part3, part4, part5, part6, body)

            global_loss = criterion(output1, target_)
            global_trunk_loss = criterion(output2, target_)
            global_part_loss = criterion(output3, target_)
            global_trunk_triplet_loss = batch_hard_triplet_loss(target_,
                                                                Zft,
                                                                margin=margin,
                                                                device=device)
            global_part_triplet_loss = batch_hard_triplet_loss(target_,
                                                               Zfl,
                                                               margin=margin,
                                                               device=device)

            loss = global_loss + 1.5 * global_trunk_loss + 1.5 * global_part_loss + 2 * global_trunk_triplet_loss + 2 * global_part_triplet_loss

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            _, pred = torch.max(output1, dim=1)
            correct += torch.sum(pred == target_).item()
            total += target_.size(0)
            if (batch_idx) % print_every == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Acc: {:.4f}'.
                      format(epoch, n_epochs, batch_idx, total_step,
                             loss.item(), 100 * correct / total))

        train_acc.append(100 * correct / total)
        train_loss.append(running_loss / total_step)
        print(
            f'\ntrain loss: {train_loss[-1]:.4f},\ntrain acc: {train_acc[-1]:.4f}'
        )
        batch_loss = 0
        total_t = 0
        correct_t = 0
        if epoch % 5 == 0:
            print('Evaluation')
            with torch.no_grad():
                model.eval()
                for (full_image, part1, part2, part3, part4, part5, part6,
                     body, target_t) in (validation_loader):

                    # move every input tensor and the target to the device
                    full_image, part1, part2, part3, part4, part5, part6, body, target_t = (
                        t.to(device) for t in (full_image, part1, part2, part3,
                                               part4, part5, part6, body,
                                               target_t))
                    output1, output2, output3, Zft, Zfl = model(
                        full_image, part1, part2, part3, part4, part5, part6,
                        body)

                    global_loss_t = criterion(output1, target_t)
                    global_trunk_loss_t = criterion(output2, target_t)
                    global_part_loss_t = criterion(output3, target_t)
                    global_trunk_triplet_loss_t = batch_hard_triplet_loss(
                        target_t, Zft, margin=margin, device=device)
                    global_part_triplet_loss_t = batch_hard_triplet_loss(
                        target_t, Zfl, margin=margin, device=device)

                    loss_t = global_loss_t + 1.5 * global_trunk_loss_t + 1.5 * global_part_loss_t + 2 * global_trunk_triplet_loss_t + 2 * global_part_triplet_loss_t

                    batch_loss += loss_t.item()
                    _, pred_t = torch.max(output1, dim=1)
                    correct_t += torch.sum(pred_t == target_t).item()
                    total_t += target_t.size(0)
                val_acc.append(100 * correct_t / total_t)
                val_loss.append(batch_loss / len(validation_loader))
                network_learned = batch_loss < valid_loss_min
                print(
                    f'validation loss: {val_loss[-1]:.4f}\nval acc: {val_acc[-1]:.4f}\n'
                )
                # Saving the best weight
                if network_learned:
                    valid_loss_min = batch_loss
                    torch.save(
                        model.state_dict(),
                        config["train_images"]["train_weights"] +
                        'ppbm_model_fold' + str(fold_index) + '.pt')
                    print('Detected network improvement, saving current model')
        model.train()
    torch.save(
        model.state_dict(), config["train_images"]["train_weights"] +
        'ppbm_model_last_fold' + str(fold_index) + '.pt')
Example no. 27
class TD3(Agent):
    def __init__(self,
                 algo_params,
                 env,
                 transition_tuple=None,
                 path=None,
                 seed=-1):
        # environment
        self.env = env
        self.env.seed(seed)
        obs = self.env.reset()
        algo_params.update({
            'state_dim': obs.shape[0],
            'action_dim': self.env.action_space.shape[0],
            'action_max': self.env.action_space.high,
            'action_scaling': self.env.action_space.high[0],
            'init_input_means': None,
            'init_input_vars': None
        })
        # training args
        self.training_episodes = algo_params['training_episodes']
        self.testing_gap = algo_params['testing_gap']
        self.testing_episodes = algo_params['testing_episodes']
        self.saving_gap = algo_params['saving_gap']

        super(TD3, self).__init__(algo_params,
                                  transition_tuple=transition_tuple,
                                  goal_conditioned=False,
                                  path=path,
                                  seed=seed)
        # torch
        self.network_dict.update({
            'actor':
            Actor(self.state_dim,
                  self.action_dim,
                  action_scaling=self.action_scaling).to(self.device),
            'actor_target':
            Actor(self.state_dim,
                  self.action_dim,
                  action_scaling=self.action_scaling).to(self.device),
            'critic_1':
            Critic(self.state_dim + self.action_dim, 1).to(self.device),
            'critic_1_target':
            Critic(self.state_dim + self.action_dim, 1).to(self.device),
            'critic_2':
            Critic(self.state_dim + self.action_dim, 1).to(self.device),
            'critic_2_target':
            Critic(self.state_dim + self.action_dim, 1).to(self.device)
        })
        self.network_keys_to_save = ['actor_target', 'critic_1_target']
        self.actor_optimizer = Adam(self.network_dict['actor'].parameters(),
                                    lr=self.actor_learning_rate)
        self._soft_update(self.network_dict['actor'],
                          self.network_dict['actor_target'],
                          tau=1)
        self.critic_1_optimizer = Adam(
            self.network_dict['critic_1'].parameters(),
            lr=self.critic_learning_rate)
        self._soft_update(self.network_dict['critic_1'],
                          self.network_dict['critic_1_target'],
                          tau=1)
        self.critic_2_optimizer = Adam(
            self.network_dict['critic_2'].parameters(),
            lr=self.critic_learning_rate)
        self._soft_update(self.network_dict['critic_2'],
                          self.network_dict['critic_2_target'],
                          tau=1)
        # behavioural policy args (exploration)
        self.exploration_strategy = GaussianNoise(self.action_dim,
                                                  self.action_max,
                                                  mu=0,
                                                  sigma=0.1)
        # training args
        self.target_noise = algo_params['target_noise']
        self.noise_clip = algo_params['noise_clip']
        self.warmup_step = algo_params['warmup_step']
        self.actor_update_interval = algo_params['actor_update_interval']
        # statistic dict
        self.statistic_dict.update({
            'episode_return': [],
            'episode_test_return': []
        })

    def run(self, test=False, render=False, load_network_ep=None, sleep=0):
        if test:
            num_episode = self.testing_episodes
            if load_network_ep is not None:
                print("Loading network parameters...")
                self._load_network(ep=load_network_ep)
            print("Start testing...")
        else:
            num_episode = self.training_episodes
            print("Start training...")

        for ep in range(num_episode):
            ep_return = self._interact(render, test, sleep=sleep)
            self.statistic_dict['episode_return'].append(ep_return)
            print("Episode %i" % ep, "return %0.1f" % ep_return)

            if (ep % self.testing_gap == 0) and (ep != 0) and (not test):
                ep_test_return = []
                for test_ep in range(self.testing_episodes):
                    ep_test_return.append(self._interact(render, test=True))
                self.statistic_dict['episode_test_return'].append(
                    sum(ep_test_return) / self.testing_episodes)
                print(
                    "Episode %i" % ep, "test return %0.1f" %
                    (sum(ep_test_return) / self.testing_episodes))

            if (ep % self.saving_gap == 0) and (ep != 0) and (not test):
                self._save_network(ep=ep)

        if not test:
            print("Finished training")
            print("Saving statistics...")
            self._plot_statistics(save_to_file=True)
        else:
            print("Finished testing")

    def _interact(self, render=False, test=False, sleep=0):
        done = False
        obs = self.env.reset()
        ep_return = 0
        # start a new episode
        while not done:
            if render:
                self.env.render()
            if self.env_step_count < self.warmup_step:
                action = self.env.action_space.sample()
            else:
                action = self._select_action(obs, test=test)
            new_obs, reward, done, info = self.env.step(action)
            time.sleep(sleep)
            ep_return += reward
            if not test:
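                # store 1 - done as a continuation mask so the Bellman target
                # in _learn() can multiply by it directly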
                self._remember(obs, action, new_obs, reward, 1 - int(done))
                if self.observation_normalization:
                    self.normalizer.store_history(new_obs)
                    self.normalizer.update_mean()
                if (self.env_step_count % self.update_interval
                        == 0) and (self.env_step_count > self.warmup_step):
                    self._learn()
            obs = new_obs
            self.env_step_count += 1
        return ep_return

    def _select_action(self, obs, test=False):
        obs = self.normalizer(obs)
        with T.no_grad():
            inputs = T.as_tensor(obs, dtype=T.float, device=self.device)
            action = self.network_dict['actor_target'](
                inputs).detach().cpu().numpy()
        if test:
            # evaluate
            return np.clip(action, -self.action_max, self.action_max)
        else:
            # explore
            return self.exploration_strategy(action)

    def _learn(self, steps=None):
        if len(self.buffer) < self.batch_size:
            return
        if steps is None:
            steps = self.optimizer_steps

        for i in range(steps):
            if self.prioritised:
                batch, weights, inds = self.buffer.sample(self.batch_size)
                weights = T.as_tensor(weights, device=self.device).view(
                    self.batch_size, 1)
            else:
                batch = self.buffer.sample(self.batch_size)
                weights = T.ones(size=(self.batch_size, 1), device=self.device)
                inds = None

            actor_inputs = self.normalizer(batch.state)
            actor_inputs = T.as_tensor(actor_inputs,
                                       dtype=T.float32,
                                       device=self.device)
            actions = T.as_tensor(batch.action,
                                  dtype=T.float32,
                                  device=self.device)
            critic_inputs = T.cat((actor_inputs, actions), dim=1)
            actor_inputs_ = self.normalizer(batch.next_state)
            actor_inputs_ = T.as_tensor(actor_inputs_,
                                        dtype=T.float32,
                                        device=self.device)
            rewards = T.as_tensor(batch.reward,
                                  dtype=T.float32,
                                  device=self.device).unsqueeze(1)
            done = T.as_tensor(batch.done, dtype=T.float32,
                               device=self.device).unsqueeze(1)

            if self.discard_time_limit:
                # treat time-limit terminations as non-terminal by forcing the
                # continuation mask to 1, so bootstrapping always happens
                done = done * 0 + 1

            with T.no_grad():
                actions_ = self.network_dict['actor_target'](actor_inputs_)
                # add noise
                noise = (T.randn_like(actions_, device=self.device) *
                         self.target_noise)
                actions_ += noise.clamp(-self.noise_clip, self.noise_clip)
                actions_ = actions_.clamp(-self.action_max[0],
                                          self.action_max[0])
                critic_inputs_ = T.cat((actor_inputs_, actions_), dim=1)
                value_1_ = self.network_dict['critic_1_target'](critic_inputs_)
                value_2_ = self.network_dict['critic_2_target'](critic_inputs_)
                value_ = T.min(value_1_, value_2_)
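                # clipped double-Q: bootstrapping from the smaller of the two
                # target critics counteracts value overestimation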
                value_target = rewards + done * self.gamma * value_

            self.critic_1_optimizer.zero_grad()
            value_estimate_1 = self.network_dict['critic_1'](critic_inputs)
            critic_loss_1 = F.mse_loss(value_estimate_1,
                                       value_target.detach(),
                                       reduction='none')
            (critic_loss_1 * weights).mean().backward()
            self.critic_1_optimizer.step()

            if self.prioritised:
                assert inds is not None
                self.buffer.update_priority(
                    inds, np.abs(critic_loss_1.cpu().detach().numpy()))

            self.critic_2_optimizer.zero_grad()
            value_estimate_2 = self.network_dict['critic_2'](critic_inputs)
            critic_loss_2 = F.mse_loss(value_estimate_2,
                                       value_target.detach(),
                                       reduction='none')
            (critic_loss_2 * weights).mean().backward()
            self.critic_2_optimizer.step()

            self.statistic_dict['critic_loss'].append(
                critic_loss_1.detach().mean())

            if self.optim_step_count % self.actor_update_interval == 0:
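                # delayed policy updates: the actor and all three target
                # networks are refreshed less often than the critics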
                self.actor_optimizer.zero_grad()
                new_actions = self.network_dict['actor'](actor_inputs)
                critic_eval_inputs = T.cat((actor_inputs, new_actions), dim=1)
                actor_loss = -self.network_dict['critic_1'](
                    critic_eval_inputs).mean()
                actor_loss.backward()
                self.actor_optimizer.step()

                self._soft_update(self.network_dict['actor'],
                                  self.network_dict['actor_target'])
                self._soft_update(self.network_dict['critic_1'],
                                  self.network_dict['critic_1_target'])
                self._soft_update(self.network_dict['critic_2'],
                                  self.network_dict['critic_2_target'])

                self.statistic_dict['actor_loss'].append(
                    actor_loss.detach().mean())

            self.optim_step_count += 1
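
The _soft_update helper used above is inherited from Agent and is not shown in
this example. A minimal sketch of what it presumably does, judging from the
calls (Polyak averaging; tau=1 at construction time makes the target an exact
copy of the source network):

    def _soft_update(self, source, target, tau=None):
        if tau is None:
            tau = self.tau  # assumed default blending factor on the Agent base class
        # target <- tau * source + (1 - tau) * target, parameter by parameter
        for t_param, s_param in zip(target.parameters(),
                                    source.parameters()):
            t_param.data.copy_(tau * s_param.data + (1 - tau) * t_param.data)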
Example no. 28
            y = upsample(y.view(-1, 3, 64, 64), size=256)
            x = upsample(x.view(-1, 3, 64, 64), size=256)

            # x - mfa image
            # y - original image

            D_result = D(x, y).squeeze()
            D_real_loss = BCE_loss(D_result, torch.ones(D_result.size()).cuda())

            G_result = G(x)
            D_result = D(x, G_result).squeeze()
            D_fake_loss = BCE_loss(D_result, torch.zeros(D_result.size()).cuda())

            D_train_loss = (D_real_loss + D_fake_loss) * 0.5
            D_train_loss.backward()
            D_optimizer.step()

            G.zero_grad()

            G_result = G(x)
            D_result = D(x, G_result).squeeze()

            G_train_loss = BCE_loss(D_result, torch.ones(D_result.size()).cuda()) + 100 * L1_loss(G_result, y)
            G_train_loss.backward()
            G_optimizer.step()

            # accumulate detached scalars so the running sums do not keep the
            # autograd graphs of every step alive
            disc_loss += D_train_loss.item()
            gen_loss += G_train_loss.item()
            count += 1

            if count % 1000 == 0:
Example no. 29
def main(args):
    args_dict = vars(args)

    print(args_dict)

    if args_dict['saved_model'] is not None and os.path.exists(
            args_dict['saved_model']):
        device, model, optimizer, saved_args_dict = models.load_model(
            args_dict['saved_model'])
        args_dict = saved_args_dict
    else:
        model = models.HourglassNetwork(num_channels=args.channels,
                                        num_stacks=args.stacks,
                                        num_classes=args.joints,
                                        input_shape=(args.input_dim,
                                                     args.input_dim, 3))
        print(torch.cuda.device_count(), "GPUs available.")

        device = torch.device(args_dict['device'])
        model = model.to(device).double()
        optimizer = Adam(model.parameters(), lr=args.lr)

    mpii_train = datasets.MPII_dataset(
        dataset_type='train',
        images_dir=args_dict['images_dir'],
        annots_json_filename=args_dict['annots_path'],
        mean_path=args_dict['mean_path'],
        std_path=args_dict['std_path'],
        input_shape=args_dict['input_dim'],
        output_shape=args_dict['output_dim'])

    mpii_valid = datasets.MPII_dataset(
        dataset_type='valid',
        images_dir=args_dict['images_dir'],
        annots_json_filename=args_dict['annots_path'],
        mean_path=args_dict['mean_path'],
        std_path=args_dict['std_path'],
        input_shape=args_dict['input_dim'],
        output_shape=args_dict['output_dim'])

    train_dataloader = DataLoader(dataset=mpii_train,
                                  batch_size=args_dict['batch_size'],
                                  shuffle=True,
                                  num_workers=0)
    valid_dataloader = DataLoader(dataset=mpii_valid,
                                  batch_size=args_dict['batch_size'],
                                  shuffle=False,
                                  num_workers=0)

    criterion = losses.JointsMSELoss().to(device)
    logger = loggers.CSVLogger(args_dict['logger_csv_path'])

    for epoch in range(args_dict.get('epoch_to_start', 0),
                       args_dict['epochs']):

        model.train()

        # decay the learning rate by 10x going into epochs 60 and 90
        if epoch == 60 - 1 or epoch == 90 - 1:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                args_dict['lr'] = param_group['lr']
                print("Learning rate changed to", args_dict['lr'])
        else:
            print("Learning rate stayed the same", args_dict['lr'])

        for i, (input_batch, output_batch,
                meta_batch) in enumerate(train_dataloader):

            # TODO adjust learning rate

            x = input_batch.to(device)
            y_kappa = output_batch.to(device, non_blocking=True)
            weights = meta_batch['label_weights'].to(device, non_blocking=True)

            y = model(x)

            loss = 0
            for _y in y:
                loss += criterion(_y, y_kappa, weights)

            joint_distances, accuracy_per_joint, average_accuracy = eval.output_accuracy(
                y=y[-1], y_kappa=y_kappa, threshold=args_dict['threshold'])

            print(
                'TRAIN: Epoch=[{}/{}], Step=[{}/{}], Loss={:.8f}, Avg_Acc: {:.5f}'
                .format(epoch + 1, args_dict['epochs'], i + 1,
                        len(train_dataloader), loss.item(), average_accuracy))

            logger.log(epoch + 1, args_dict['epochs'], i,
                       len(train_dataloader), 'train', loss,
                       accuracy_per_joint, average_accuracy)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()

        with torch.no_grad():

            for i, (input_batch, output_batch,
                    meta_batch) in enumerate(valid_dataloader):
                x = input_batch.to(device)
                y_kappa = output_batch.to(device, non_blocking=True)
                weights = meta_batch['label_weights'].to(device,
                                                         non_blocking=True)

                y = model(x)

                loss = 0
                for _y in y:
                    loss += criterion(_y, y_kappa, weights)

                joint_distances, accuracy_per_joint, average_accuracy = eval.output_accuracy(
                    y=y[-1], y_kappa=y_kappa, threshold=args_dict['threshold'])

                print(
                    'VALID: Epoch=[{}/{}], Step=[{}/{}], Loss={:.8f}, Avg_Acc: {:.5f}'
                    .format(epoch + 1, args_dict['epochs'], i + 1,
                            len(valid_dataloader), loss.item(),
                            average_accuracy))

                logger.log(epoch + 1, args_dict['epochs'], i,
                           len(valid_dataloader), 'valid', loss,
                           accuracy_per_joint, average_accuracy)

        args_dict['epoch_to_start'] = epoch + 1

        if epoch % 5 == 0:
            models.save_model(model=model,
                              optimizer=optimizer,
                              args_dict=args_dict,
                              save_path=args_dict['saved_path'] + str(epoch))

        models.save_model(model=model,
                          optimizer=optimizer,
                          args_dict=args_dict,
                          save_path=args_dict['saved_model'])
Example no. 30
class SAC:
    def __init__(self, env_name, n_states, n_actions, memory_size, batch_size,
                 gamma, alpha, lr, action_bounds, reward_scale):
        self.env_name = env_name
        self.n_states = n_states
        self.n_actions = n_actions
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.alpha = alpha
        self.lr = lr
        self.action_bounds = action_bounds
        self.reward_scale = reward_scale
        self.memory = Memory(memory_size=self.memory_size)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.policy_network = PolicyNetwork(
            n_states=self.n_states,
            n_actions=self.n_actions,
            action_bounds=self.action_bounds).to(self.device)
        self.q_value_network1 = QvalueNetwork(n_states=self.n_states,
                                              n_actions=self.n_actions).to(
                                                  self.device)
        self.q_value_network2 = QvalueNetwork(n_states=self.n_states,
                                              n_actions=self.n_actions).to(
                                                  self.device)
        self.value_network = ValueNetwork(n_states=self.n_states).to(
            self.device)
        self.value_target_network = ValueNetwork(n_states=self.n_states).to(
            self.device)
        self.value_target_network.load_state_dict(
            self.value_network.state_dict())
        self.value_target_network.eval()

        self.value_loss = torch.nn.MSELoss()
        self.q_value_loss = torch.nn.MSELoss()

        self.value_opt = Adam(self.value_network.parameters(), lr=self.lr)
        self.q_value1_opt = Adam(self.q_value_network1.parameters(),
                                 lr=self.lr)
        self.q_value2_opt = Adam(self.q_value_network2.parameters(),
                                 lr=self.lr)
        self.policy_opt = Adam(self.policy_network.parameters(), lr=self.lr)

    def store(self, state, reward, done, action, next_state):
        state = from_numpy(state).float().to("cpu")
        reward = torch.Tensor([reward]).to("cpu")
        done = torch.Tensor([done]).to("cpu")
        action = torch.Tensor([action]).to("cpu")
        next_state = from_numpy(next_state).float().to("cpu")
        self.memory.add(state, reward, done, action, next_state)

    def unpack(self, batch):
        batch = Transition(*zip(*batch))

        states = torch.cat(batch.state).view(self.batch_size,
                                             self.n_states).to(self.device)
        rewards = torch.cat(batch.reward).view(self.batch_size,
                                               1).to(self.device)
        dones = torch.cat(batch.done).view(self.batch_size, 1).to(self.device)
        actions = torch.cat(batch.action).view(-1,
                                               self.n_actions).to(self.device)
        next_states = torch.cat(batch.next_state).view(
            self.batch_size, self.n_states).to(self.device)

        return states, rewards, dones, actions, next_states

    def train(self):
        if len(self.memory) < self.batch_size:
            return 0, 0, 0
        else:
            batch = self.memory.sample(self.batch_size)
            states, rewards, dones, actions, next_states = self.unpack(batch)

            # Calculating the value target
            reparam_actions, log_probs = self.policy_network.sample_or_likelihood(
                states)
            q1 = self.q_value_network1(states, reparam_actions)
            q2 = self.q_value_network2(states, reparam_actions)
            q = torch.min(q1, q2)
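            # the soft value target bootstraps from the smaller critic estimate
            # minus the entropy term alpha * log_probs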
            target_value = q.detach() - self.alpha * log_probs.detach()

            value = self.value_network(states)
            value_loss = self.value_loss(value, target_value)

            # Calculating the Q-Value target
            with torch.no_grad():
                target_q = self.reward_scale * rewards + \
                           self.gamma * self.value_target_network(next_states) * (1 - dones)
            q1 = self.q_value_network1(states, actions)
            q2 = self.q_value_network2(states, actions)
            q1_loss = self.q_value_loss(q1, target_q)
            q2_loss = self.q_value_loss(q2, target_q)

            policy_loss = (self.alpha * log_probs - q).mean()

            self.policy_opt.zero_grad()
            policy_loss.backward()
            self.policy_opt.step()

            self.value_opt.zero_grad()
            value_loss.backward()
            self.value_opt.step()

            self.q_value1_opt.zero_grad()
            q1_loss.backward()
            self.q_value1_opt.step()

            self.q_value2_opt.zero_grad()
            q2_loss.backward()
            self.q_value2_opt.step()

            self.soft_update_target_network(self.value_network,
                                            self.value_target_network)

            return value_loss.item(), 0.5 * (
                q1_loss + q2_loss).item(), policy_loss.item()

    def choose_action(self, states):
        states = np.expand_dims(states, axis=0)
        states = from_numpy(states).float().to(self.device)
        action, _ = self.policy_network.sample_or_likelihood(states)
        return action.detach().cpu().numpy()[0]

    @staticmethod
    def soft_update_target_network(local_network, target_network, tau=0.005):
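        # Polyak averaging: target <- tau * local + (1 - tau) * target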
        for target_param, local_param in zip(target_network.parameters(),
                                             local_network.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1 - tau) * target_param.data)

    def save_weights(self):
        torch.save(self.policy_network.state_dict(),
                   self.env_name + "_weights.pth")

    def load_weights(self):
        self.policy_network.load_state_dict(
            torch.load(self.env_name + "_weights.pth"))

    def set_to_eval_mode(self):
        self.policy_network.eval()
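
A minimal driver loop for this agent; the environment and every number below
are illustrative, not taken from the original code:

# env is assumed to be a Gym-style environment created elsewhere
agent = SAC(env_name="Pendulum-v1", n_states=3, n_actions=1,
            memory_size=100_000, batch_size=256, gamma=0.99, alpha=0.2,
            lr=3e-4, action_bounds=[-2.0, 2.0], reward_scale=1.0)

state = env.reset()
for step in range(10_000):
    action = agent.choose_action(state)
    next_state, reward, done, _ = env.step(action)
    agent.store(state, reward, done, action, next_state)
    value_loss, q_loss, policy_loss = agent.train()
    state = env.reset() if done else next_state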