Ejemplo n.º 1
0
        print('Create Model!')
    som = som.to(device)

    if train == True:
        losses = list()
        for epoch in range(total_epoch):
            running_loss = 0
            start_time = time.time()
            for idx, (X, Y) in enumerate(train_loader):
                X = X.view(-1, 28 * 28 * 1).to(device)  # flatten
                loss = som.self_organizing(X, epoch, total_epoch)  # train som
                running_loss += loss

            losses.append(running_loss)
            print('epoch = %d, loss = %.2f, time = %.2fs' %
                  (epoch + 1, running_loss, time.time() - start_time))

            if epoch % 5 == 0:
                # save
                som.save_result('%s/som_epoch_%d.png' % (RES_DIR, epoch),
                                (1, 28, 28))
                torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)

        torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)
        plt.title('SOM loss')
        plt.plot(losses)
        plt.show()

    som.save_result('%s/som_result.png' % (RES_DIR), (1, 28, 28))
    torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)
Ejemplo n.º 2
0
        if train == True:
            losses = list()
            for epoch in tqdm(range(total_epoch)):
                running_loss = 0
                start_time = time.time()
                for idx, b in enumerate(train_dataloader):
                    X=torch.FloatTensor(b['mol_fp'])
                    #print(f'X.size{X.size()}')
                    X = X.view(-1, input_dim).to(device)    # flatten
                    loss = som.self_organizing(X, epoch, total_epoch)    # train som
                    running_loss += loss
                losses.append(running_loss)
                print('epoch = %d, loss = %.2f, time = %.2fs' % (epoch + 1, running_loss, time.time() - start_time))
                if epoch % 100 == 0:
                    # save TODO som.save_result for molecules save
                    torch.save(som.state_dict(), f'{args.save_dir}/som_F{f_i}Epoch{epoch}.pth')
            torch.save(som.state_dict(), f'{args.save_dir}/som_F{f_i}.pth')
            #plot loss trend for each fold training
            plt.title('SOM loss')
            plt.plot(losses)
            #plt.show()
            plt.savefig(f'{args.save_dir}/train_loss{f_i}.png')
            plt.clf()#https://stackoverflow.com/questions/21884271/warning-about-too-many-open-figures
            plt.close()

            save_f_val = args.save_dir + '/' + f'val_heatmapF{f_i}.png'
            save_f_test = args.save_dir + '/' + f'test_heatmapF{f_i}.png'
            # testing   and pliot heatmap each fold
            x_val=torch.FloatTensor(val)
            x_val = x_val.view(-1, input_dim).to(device)
            val_clusters = predict(som,x_val)