def dembed_loss(real_one_hot, gen_one_hot, show_partial=False):
    """Return the de-embedding loss: mean-squared error between two tensors.

    Args:
        real_one_hot: reference one-hot tensor.
        gen_one_hot: generated tensor compared against ``real_one_hot``
            (expected to have the same shape).
        show_partial: when True, print the scalar loss value.

    Returns:
        Scalar loss tensor from ``MSELoss()`` (default 'mean' reduction).
    """
    criterion = MSELoss()
    loss = criterion(real_one_hot, gen_one_hot)
    if show_partial:
        print(f'\tDEEMB deemb_loss {loss.item()}')
    return loss
def ae_main_loss(input, output, lat_mu, lat_var, show_partial=False,
                 kld_weight=0.0) -> Variable:
    """Return the autoencoder loss: reconstruction MSE plus a weighted KL term.

    Args:
        input: original tensor fed to the autoencoder.
        output: reconstruction, compared to ``input`` with mean-squared error.
        lat_mu: latent means; the KL term sums over dim 1 and averages over
            dim 0, so this is presumably (batch, latent_dim) -- confirm.
        lat_var: treated as a log-variance (the code uses ``lat_var.exp()``);
            same shape as ``lat_mu``.
        show_partial: when True, print both loss components.
        kld_weight: multiplier on the KL term. Defaults to 0.0, which matches
            the previous hard-coded ``kld_loss * 0.00`` (KL term disabled).

    Returns:
        Scalar loss tensor.
    """
    mse_loss = MSELoss()(input, output)
    # Closed-form KL divergence of N(mu, exp(logvar)) from N(0, 1), summed over
    # the latent dimension (dim=1) and averaged over the batch (dim=0).
    kld_loss = torch.mean(
        -0.5 * torch.sum(1 + lat_var - lat_mu ** 2 - lat_var.exp(), dim=1), dim=0)

    if show_partial:
        print('\tMAIN mse_loss ' + str(mse_loss.item()))
        print('\tMAIN kld_loss ' + str(kld_loss.item()))

    # Leave the KL term out entirely when its weight is zero: if
    # ``lat_var.exp()`` overflows to inf, ``0 * kld_loss`` is NaN and would
    # poison both the loss and its gradients.
    if kld_weight == 0:
        return mse_loss
    return mse_loss + kld_loss * kld_weight
    def one_batch_train(self, Xs, Y):
        """Run one optimisation step on a single batch, then checkpoint.

        Args:
            Xs: mapping whose ``'word_vectors'`` entry is fed to ``self.model``.
            Y: target compared to the model output with mean-squared error.

        Returns:
            ``(predictions, loss)``: the batch model outputs as a NumPy array
            and the batch MSE as a Python float.
        """
        Ws = Xs['word_vectors']

        self.optimizer.zero_grad()
        outputs = self.model(Ws)
        # MSELoss is a module: instantiate it, then call it on (input, target).
        # ``MSELoss(outputs, Y)`` would pass the tensors as constructor
        # arguments and never compute a loss to backpropagate.
        mse_loss = MSELoss()(outputs, Y)

        mse_loss.backward()
        self.optimizer.step()
        # NOTE(review): checkpointing after every batch is expensive; kept as-is.
        self.save_checkpoint('../../../data/nn/lstm.pth.tar')
        # Return this batch's predictions. Assuming self.model is a
        # torch.nn.Module, it has no ``.data`` attribute to return.
        return outputs.detach().cpu().numpy(), mse_loss.item()
def ae_prtcl_loss(input_x, output_x, show_partial=False) -> Variable:
    """Return the particle autoencoder loss: reconstruction MSE only.

    Args:
        input_x: original tensor.
        output_x: reconstruction compared to ``input_x``.
        show_partial: when True, print the MSE value.

    Returns:
        Scalar MSE loss tensor.
    """
    # A VAE-style KL term (formerly weighted 0.001) is disabled here; only the
    # plain reconstruction error is returned.
    reconstruction = MSELoss()(input_x, output_x)
    if show_partial:
        print(f'\tPRTCL mse_loss {reconstruction.item()}')
    return reconstruction
# Example #5
# 0
def train_model():
    """Train ``model`` for one epoch over ``train_loader``.

    The per-batch objective is RMSE (square root of MSE) between predictions
    and float labels. When ``args.neg_edge_percent`` is not 100, only the
    first ``neg_edge_percent`` percent of the zero-labelled samples in each
    batch are kept; all non-zero-labelled samples are kept.

    Depends on module-level names defined elsewhere: ``model``, ``optimizer``,
    ``train_loader``, ``train_dataset``, ``args``, ``device``, ``emb``, ``tqdm``.

    Returns:
        Sum over batches of ``rmse * len(batch)``, divided by
        ``len(train_dataset)``.
    """
    model.train()

    running_loss = 0
    for batch_data in tqdm(train_loader):
        if not args.multi_gpu:
            batch_data = batch_data.to(device)
        optimizer.zero_grad()

        if not args.use_orig_graph:
            pred = model(batch_data).view(-1)
        else:
            feats = batch_data.x if args.use_feature else None
            weights = batch_data.edge_weight if args.use_edge_weight else None
            ids = batch_data.node_id if emb else None
            pred = model(batch_data.z, batch_data.edge_index, batch_data.batch,
                         feats, weights, ids).view(-1)

        if args.multi_gpu:
            # Here ``batch_data`` is iterated item by item; labels are gathered
            # onto the prediction's device.
            labels = [item.y.to(torch.float) for item in batch_data]
            target = torch.cat(labels).to(pred.device)
        else:
            target = batch_data.y.to(torch.float)

        if args.neg_edge_percent != 100:
            neg_mask = target == 0
            pos_mask = ~neg_mask
            n_keep = int(args.neg_edge_percent / 100 * int(neg_mask.sum()))
            # Positives first, followed by the truncated negatives.
            pred = torch.cat([pred[pos_mask], pred[neg_mask][:n_keep]])
            target = torch.cat([target[pos_mask], target[neg_mask][:n_keep]])

        rmse = torch.sqrt(MSELoss()(pred, target))
        rmse.backward()
        optimizer.step()
        running_loss += rmse.item() * len(batch_data)

    return running_loss / len(train_dataset)
# Example #6
# 0
def main(screen):
    """Estimate the chain's total mass by gradient descent on trajectory MSE.

    A reference trajectory is simulated with ``TOTAL_MASS``. Starting from a
    random mass, RMSprop then minimises the MSE between re-simulated and
    reference positions. It stops after ``max_iter`` iterations, or earlier
    once the loss changes by less than ``STOP_DIFF`` between iterations. The
    loss and mass histories are plotted at the end.

    Args:
        screen: forwarded to the reference and final simulation runs
            (presumably a rendering surface -- confirm against
            ``positions_run_world``).
    """
    forces = [hor_impulse]
    ground_truth_mass = torch.tensor([TOTAL_MASS], dtype=DTYPE)
    world, chain = make_world(forces, ground_truth_mass, num_links=NUM_LINKS)

    rec = None
    # rec = Recorder(DT, screen)
    ground_truth_pos = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    # ``.data`` strips autograd tracking so the reference acts as a constant target.
    ground_truth_pos = torch.cat([p.data for p in ground_truth_pos])

    learning_rate = 0.5
    max_iter = 100

    next_mass = torch.rand_like(ground_truth_mass, requires_grad=True)
    print('\rInitial mass:', next_mass.item())
    print('-----')

    optim = torch.optim.RMSprop([next_mass], lr=learning_rate)
    loss_hist = []
    mass_hist = [next_mass.item()]
    # Keep the previous loss as a Python float rather than a tensor so the
    # previous iteration's autograd graph is not kept alive.
    last_loss = 1e10
    for i in range(max_iter):
        # Side run on a detached copy of the current mass; its result is not
        # used by the optimisation.
        world, chain = make_world(forces, next_mass.clone().detach(), num_links=NUM_LINKS)
        run_world(world, run_time=10, print_time=False, screen=None, recorder=None)

        world, chain = make_world(forces, next_mass, num_links=NUM_LINKS)
        positions = torch.cat(positions_run_world(world, run_time=10, screen=None))
        # The runs may differ in length; compare only the common prefix.
        positions = positions[:len(ground_truth_pos)]
        clipped_ground_truth_pos = ground_truth_pos[:len(positions)]

        optim.zero_grad()
        loss = MSELoss()(positions, clipped_ground_truth_pos)
        loss.backward()
        optim.step()

        loss_value = loss.item()
        print('Iteration: {} / {}'.format(i + 1, max_iter))
        print('Loss:', loss_value)
        print('Gradient:', next_mass.grad.item())
        print('Next mass:', next_mass.item())
        print('-----')
        # Record before the convergence check so the last optimisation step
        # (whose mass is reported as "Final mass") is part of the history.
        loss_hist.append(loss_value)
        mass_hist.append(next_mass.item())
        if abs(last_loss - loss_value) < STOP_DIFF:
            print('Loss changed by less than {} between iterations, stopping training.'
                  .format(STOP_DIFF))
            break
        last_loss = loss_value

    world = make_world(forces, next_mass, num_links=NUM_LINKS)[0]
    rec = None
    positions = torch.cat(positions_run_world(world, run_time=10, screen=screen, recorder=rec))
    positions = positions[:len(ground_truth_pos)]
    clipped_ground_truth_pos = ground_truth_pos[:len(positions)]
    loss = MSELoss()(positions, clipped_ground_truth_pos)
    print('Final loss:', loss.item())
    print('Final mass:', next_mass.item())

    plot(loss_hist)
    plot(mass_hist)
# Example #7
# 0
def main(screen):
    """Optimise the control tensor ``utT`` so the simulated body reaches a target.

    Each iteration rebuilds the world from ``ctT``/``utT`` and simulates it.
    It then minimises the sum of three terms:
    - the MSE between the flattened (x, y) trajectory and the fixed point
      (500, 240), divided by 10;
    - a control-effort term, currently weighted 0;
    - the absolute values of the first three components of the final velocity.

    Progress is plotted live, and ``utT`` is saved every 20 iterations into a
    timestamped directory.

    Depends on module-level names defined elsewhere: ``utT``, ``ctT``,
    ``runtime``, ``make_world``, ``positions_run_world``, ``DTYPE``,
    ``Defaults``, ``MSELoss``, ``plt``, ``os``, ``datetime`` and ``torch``.
    """
    timestampStr = datetime.now().strftime("%d-%b-%Y(%H:%M)")
    print('Current Timestamp : ', timestampStr)
    # NOTE(review): ':' in this directory name is not valid on Windows.
    os.makedirs(timestampStr, exist_ok=True)

    forces = []

    rec = None
    # rec = Recorder(DT, screen)

    learning_rate = 0.5
    num_iters = 20000        # the loop runs i = 1 .. num_iters - 1
    save_every = 20          # clear the plots and checkpoint utT every N iterations
    control_weight = 0       # the control-effort term is currently disabled
    target_point = [500, 240]

    # One figure holds two stacked axes: loss on top (211), trajectory below
    # (212). ``plt.subplots(21)`` would create 21 stacked axes instead.
    plt.figure()
    plt.subplot(212)
    plt.gca().invert_yaxis()
    # Switch to interactive mode once, so plt.show() does not block the loop.
    plt.ion()
    plt.show()

    optim = torch.optim.RMSprop([utT], lr=learning_rate)
    lossHist = []

    for i in range(1, num_iters):
        world, chain = make_world(forces, ctT, utT)

        hist = positions_run_world(world,
                                   run_time=runtime,
                                   screen=screen,
                                   recorder=rec)
        xy, vel, control = hist[0], hist[1], hist[2]
        xy = torch.cat(xy).to(device=Defaults.DEVICE)
        control = torch.cat(control).to(device=Defaults.DEVICE)
        optim.zero_grad()

        # Use one target (x, y) pair per two xy entries, rounding up. The
        # targets are constants, so they do not need requires_grad.
        n_pairs = (xy.size(0) + 1) // 2
        tt = torch.tensor(target_point, dtype=DTYPE,
                          device=Defaults.DEVICE).repeat(n_pairs)
        tc = torch.zeros(2 * n_pairs, dtype=DTYPE, device=Defaults.DEVICE)

        loss = (MSELoss()(xy, tt) / 10
                + control_weight * MSELoss()(control, tc) / 1000
                + abs(vel[-1][0]) + abs(vel[-1][1]) + abs(vel[-1][2]))
        loss.backward()

        lossHist.append(loss.item())
        optim.step()

        print('Loss:', loss.item())
        print('Gradient:', utT.grad)
        print('Next u:', utT)

        plt.subplot(211)
        plt.plot(lossHist)
        plt.draw()

        # Convert the trajectory to NumPy once. Every 4th entry is plotted,
        # starting at 0 for x and at 1 for y.
        traj = xy.detach().cpu().numpy()
        plt.subplot(212)
        plt.plot(traj[::4], traj[1::4])
        plt.draw()

        plt.pause(0.001)
        plt.savefig('2step.png')

        if i % save_every == 0:
            plt.subplot(211)
            plt.cla()
            plt.subplot(212)
            plt.cla()
            plt.gca().invert_yaxis()
            lossHist.clear()
            torch.save(utT, timestampStr + '/file' + str(i) + '.pt')

    # Replay once with the final optimised control, without rendering.
    world, chain = make_world(forces, ctT, utT.clone())

    positions_run_world(world, run_time=runtime, screen=None, recorder=rec)