Example #1
import torch

def cleit_train_step(model, reference_encoder, transmitter, s_batch, t_batch,
                     device, optimizer, alpha, history, scheduler=None):
    model.train()
    transmitter.train()
    reference_encoder.eval()  # the reference encoder stays frozen

    s_x = s_batch[0].to(device)
    s_y = s_batch[1].to(device)  # loaded but not used by this step
    t_x = t_batch[0].to(device)
    t_y = t_batch[1].to(device)

    # Align the transmitted target code with the frozen reference code.
    x_m_code = transmitter(model.encoder(t_x))
    with torch.no_grad():  # no gradients needed for the frozen encoder
        x_g_code = reference_encoder(s_x)

    code_loss = contrastive_loss(y_true=x_g_code, y_pred=x_m_code, device=device)
    loss = masked_simse(preds=model(t_x), labels=t_y) + alpha * code_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    history['loss'].append(loss.detach().cpu().item())
    history['code_loss'].append(code_loss.detach().cpu().item())

    return history
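
Note: masked_simse and contrastive_loss are not defined in any of these examples; they are assumed to come from the project's loss utilities. The sketch below shows one plausible reading (a scale-invariant MSE over observed labels, and an NT-Xent-style alignment loss between paired codes); these are illustrative assumptions, and the real definitions may differ.

import torch
import torch.nn.functional as F

def masked_simse(preds, labels):
    # Hypothetical sketch: scale-invariant MSE computed only over
    # observed labels (missing entries assumed encoded as NaN).
    mask = ~torch.isnan(labels)
    diff = preds[mask] - labels[mask]
    return (diff ** 2).mean() - diff.mean() ** 2

def contrastive_loss(y_true, y_pred, device):
    # Hypothetical sketch (NT-Xent-style): treat each matched
    # (y_pred, y_true) row pair as the positive and every other row
    # in the batch as a negative, with temperature fixed at 1.0.
    y_true = F.normalize(y_true, dim=1)
    y_pred = F.normalize(y_pred, dim=1)
    logits = y_pred @ y_true.t()  # (B, B) cosine similarity matrix
    targets = torch.arange(logits.size(0), device=device)
    return F.cross_entropy(logits, targets)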
Example #2
import torch

def cleit_train_step(ae, reference_encoder, batch, device, optimizer,
                     history, scheduler=None):
    ae.train()
    reference_encoder.eval()  # the reference encoder stays frozen

    x_m = batch[0].to(device)
    x_g = batch[1].to(device)

    # Reconstruction losses from the autoencoder forward pass.
    loss_dict = ae.loss_function(*ae(x_m))

    # Align the autoencoder code with the frozen reference code.
    x_m_code = ae.encoder(x_m)
    with torch.no_grad():  # no gradients needed for the frozen encoder
        x_g_code = reference_encoder(x_g)

    code_loss = contrastive_loss(y_true=x_g_code, y_pred=x_m_code, device=device)
    loss = loss_dict['loss'] + code_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    for k, v in loss_dict.items():
        # Detach so the history does not keep computation graphs alive.
        history[k].append(v.detach().cpu().item() if torch.is_tensor(v) else v)
    history['code_loss'].append(code_loss.detach().cpu().item())
    return history
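
A minimal driver for this step might look like the following. ToyAE, the toy reference encoder, and every hyperparameter are hypothetical stand-ins, shown only to make the calling convention concrete; the snippet relies on the contrastive_loss sketch above.

from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyAE(nn.Module):
    # Hypothetical stand-in for the project's autoencoder.
    def __init__(self, d_in=32, d_code=8):
        super().__init__()
        self.encoder = nn.Linear(d_in, d_code)
        self.decoder = nn.Linear(d_code, d_in)

    def forward(self, x):
        return self.decoder(self.encoder(x)), x

    def loss_function(self, recon, x):
        return {'loss': F.mse_loss(recon, x)}

device = torch.device('cpu')
ae = ToyAE().to(device)
reference_encoder = nn.Linear(16, 8).to(device)   # frozen reference
optimizer = torch.optim.AdamW(ae.parameters(), lr=1e-4)
history = defaultdict(list)

batch = (torch.randn(4, 32), torch.randn(4, 16))  # one (x_m, x_g) pair
history = cleit_train_step(ae, reference_encoder, batch, device,
                           optimizer, history)
print(history['loss'], history['code_loss'])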
Example #3
import torch

def cleit_train_step(ae, reference_encoder, transmitter, batch, device,
                     optimizer, history, scheduler=None):
    ae.train()
    transmitter.train()
    reference_encoder.eval()  # the reference encoder stays frozen

    x_m = batch[0].to(device)
    x_g = batch[1].to(device)

    # Reconstruction losses from the autoencoder forward pass.
    loss_dict = ae.loss_function(*ae(x_m))

    # Pass the autoencoder code through the transmitter before alignment.
    x_m_code = transmitter(ae.encoder(x_m))
    with torch.no_grad():  # no gradients needed for the frozen encoder
        x_g_code = reference_encoder(x_g)

    code_loss = contrastive_loss(y_true=x_g_code, y_pred=x_m_code, device=device)
    loss = loss_dict['loss'] + code_loss

    optimizer.zero_grad()
    loss.backward()
    # Optional safeguard against exploding gradients, e.g.:
    # torch.nn.utils.clip_grad_norm_(
    #     chain(ae.parameters(), transmitter.parameters()), 0.1)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    for k, v in loss_dict.items():
        # Detach so the history does not keep computation graphs alive.
        history[k].append(v.detach().cpu().item() if torch.is_tensor(v) else v)
    history['code_loss'].append(code_loss.detach().cpu().item())
    return history
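
Since this variant updates both the autoencoder and the transmitter, the optimizer is presumably built over both parameter sets, mirroring the chain(...) pattern in the optional clipping noted above. A hypothetical construction:

from itertools import chain

import torch

# Hypothetical: one optimizer over both trainable modules. `ae` and
# `transmitter` are the modules passed to cleit_train_step above.
optimizer = torch.optim.AdamW(
    chain(ae.parameters(), transmitter.parameters()), lr=1e-4)

# If gradients explode, clip over a *fresh* chain (chain objects are
# one-shot iterators) between loss.backward() and optimizer.step():
torch.nn.utils.clip_grad_norm_(
    chain(ae.parameters(), transmitter.parameters()), max_norm=0.1)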
Example #4
import torch

def ceae_train_step(p_ae, t_ae, transmitter, batch, device, optimizer,
                    history, scheduler=None):
    p_ae.train()
    t_ae.train()
    transmitter.train()

    x_p = batch[0].to(device)
    x_t = batch[1].to(device)

    # Reconstruction losses for both autoencoders.
    p_loss_dict = p_ae.loss_function(*p_ae(x_p))
    t_loss_dict = t_ae.loss_function(*t_ae(x_t))

    # Align the two latent codes through the transmitter.
    x_t_code = t_ae.encode(x_t)
    x_p_code = p_ae.encode(x_p)
    code_loss = contrastive_loss(y_true=transmitter(x_t_code),
                                 y_pred=x_p_code,
                                 device=device)

    loss = p_loss_dict['loss'] + t_loss_dict['loss'] + code_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    # Detach so the history does not keep computation graphs alive.
    for k, v in p_loss_dict.items():
        history[f'p_{k}'].append(v.detach().cpu().item() if torch.is_tensor(v) else v)
    for k, v in t_loss_dict.items():
        history[f't_{k}'].append(v.detach().cpu().item() if torch.is_tensor(v) else v)
    history['code_loss'].append(code_loss.detach().cpu().item())
    return history
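
All of these steps append to history[k] for keys that only become known at runtime (p_loss, t_loss, code_loss, and so on), so history is presumably a defaultdict(list):

from collections import defaultdict

# Assumed shape of the history argument: new keys are created lazily,
# each holding the per-step values of one loss term.
history = defaultdict(list)

# After some training steps, the latest value of any tracked term can
# be read back, e.g. history['code_loss'][-1].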
Example #5
import torch

def ceae_train_step(ae, transmitter, batch, device, optimizer, history,
                    scheduler=None):
    ae.train()
    transmitter.train()

    x_p = batch[0].to(device)
    x_t = batch[1].to(device)

    # Reconstruction loss from the autoencoder forward pass.
    loss_dict = ae.loss_function(*ae(x_p))

    # Here the second view is already a code; align the transmitted
    # autoencoder code to it directly.
    x_t_code = x_t
    x_p_code = ae.encode(x_p)
    code_loss = contrastive_loss(y_true=x_t_code,
                                 y_pred=transmitter(x_p_code),
                                 device=device)

    loss = loss_dict['loss'] + code_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    for k, v in loss_dict.items():
        # Detach so the history does not keep computation graphs alive.
        history[k].append(v.detach().cpu().item() if torch.is_tensor(v) else v)
    history['code_loss'].append(code_loss.detach().cpu().item())
    return history
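
Every variant calls scheduler.step() once per batch, so a per-iteration schedule such as PyTorch's OneCycleLR matches the calling convention. A hypothetical pairing; all names except the torch APIs are stand-ins:

import torch

# Hypothetical scheduler setup for the per-batch scheduler.step() calls
# in the functions above. steps_per_epoch, num_epochs, train_loader,
# ae, transmitter, device, and history are stand-ins.
steps_per_epoch, num_epochs = 100, 50

optimizer = torch.optim.AdamW(
    list(ae.parameters()) + list(transmitter.parameters()), lr=1e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, total_steps=steps_per_epoch * num_epochs)

for batch in train_loader:  # yields (x_p, x_t) pairs
    history = ceae_train_step(ae, transmitter, batch, device,
                              optimizer, history, scheduler=scheduler)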