예제 #1
0
                # Per-sentence evaluation step (tail of a loop whose header is outside this view).
                # Frame-level prediction: arg-max over dim 1 of pout, i.e. one class index per row.
                # NOTE(review): assumes pout is (n_frames, n_classes) scores and lab holds one
                # label per frame — both are built outside this view; confirm.
                pred=torch.max(pout,dim=1)[1]
                # `cost` is the criterion defined elsewhere; labels are cast to int64 for it.
                loss = cost(pout, lab.long())

                # Frame error rate for this sentence: fraction of rows whose predicted class
                # differs from the label.
                err = torch.mean((pred!=lab.long()).float())
    
                # Sentence-level decision: sum the scores over dim 0 and take the arg-max
                # class. `val` (the max summed score) is not used afterwards.
                [val,best_class]=torch.max(torch.sum(pout,dim=0),0)
                # lab[0] is used as the single sentence label — presumably every entry of lab
                # is the same label; verify against where lab is built.
                err_sum_snt=err_sum_snt+(best_class!=lab[0]).float()
    
    
                # detach() strips autograd history from the values added to the running sums.
                loss_sum=loss_sum+loss.detach()
                err_sum=err_sum+err.detach()
    
            # Average the accumulated metrics over snt_te (presumably the number of evaluated
            # sentences, i.e. the iteration count of the enclosing loop — confirm).
            err_tot_dev_snt=err_sum_snt/snt_te
            loss_tot_dev=loss_sum/snt_te
            err_tot_dev=err_sum/snt_te

  
        # Report training (loss_tr, err_tr) and test (loss_te, err_te frame-level,
        # err_te_snt sentence-level) metrics for this epoch. `%f` requires each value to be
        # convertible to float (a Python number or a one-element tensor).
        print("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f" % (epoch, loss_tot,err_tot,loss_tot_dev,err_tot_dev,err_tot_dev_snt))
  
        # Append the same line to <output_folder>/res.res ("a" mode keeps earlier lines).
        with open(output_folder+"/res.res", "a") as res_file:
            res_file.write("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f\n" % (epoch, loss_tot,err_tot,loss_tot_dev,err_tot_dev,err_tot_dev_snt))

        # Save the state dicts of the CNN and the two DNN stages. The epoch number is part of
        # the filename, so every evaluated epoch produces a separate checkpoint file.
        checkpoint={'CNN_model_par': CNN_net.state_dict(),
               'DNN1_model_par': DNN1_net.state_dict(),
               'DNN2_model_par': DNN2_net.state_dict(),
               }
        torch.save(checkpoint,output_folder+'/model_raw_'+ str(epoch) +'.pkl')
  
    else:
        # Branch taken when the (unseen) enclosing `if` is false: only training metrics are
        # printed, and no checkpoint is written here.
        print("epoch %i / %i, loss_tr=%f err_tr=%f" % (epoch, N_epochs, loss_tot,err_tot))
예제 #2
0
                # Sentence-level decision: sum pout over dim 0 and take the arg-max class.
                # `val` (the max summed score) is not used afterwards.
                # NOTE(review): pout, lab, loss and err are produced above this view; pout is
                # presumably (n_frames, n_classes) scores — confirm.
                [val, best_class] = torch.max(torch.sum(pout, dim=0), 0)
                # lab[0] is used as the single sentence label — presumably all entries of lab
                # are equal; verify against where lab is built.
                err_sum_snt = err_sum_snt + (best_class != lab[0]).float()

                # detach() strips autograd history from the values added to the running sums.
                loss_sum = loss_sum + loss.detach()
                err_sum = err_sum + err.detach()

            # Average the accumulated metrics over snt_te (presumably the number of evaluated
            # sentences, i.e. the iteration count of the enclosing loop — confirm).
            err_tot_dev_snt = err_sum_snt / snt_te
            loss_tot_dev = loss_sum / snt_te
            err_tot_dev = err_sum / snt_te

        # Report training (loss_tr, err_tr) and test (loss_te, err_te frame-level,
        # err_te_snt sentence-level) metrics for this epoch. `%f` requires each value to be
        # convertible to float (a Python number or a one-element tensor).
        print(
            "epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f"
            % (epoch, loss_tot, err_tot, loss_tot_dev, err_tot_dev,
               err_tot_dev_snt))

        # Append the same line to <output_folder>/res.res ("a" mode keeps earlier lines).
        with open(output_folder + "/res.res", "a") as res_file:
            res_file.write(
                "epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f\n"
                % (epoch, loss_tot, err_tot, loss_tot_dev, err_tot_dev,
                   err_tot_dev_snt))

        # Save the state dicts of the CNN and the two DNN stages. The filename is fixed, so
        # each save overwrites the previous model_raw.pkl (only the latest one is kept).
        checkpoint = {
            'CNN_model_par': CNN_net.state_dict(),
            'DNN1_model_par': DNN1_net.state_dict(),
            'DNN2_model_par': DNN2_net.state_dict(),
        }
        torch.save(checkpoint, output_folder + '/model_raw.pkl')

    else:
        # Branch taken when the (unseen) enclosing `if` is false: only training metrics are
        # printed, and no checkpoint is written here.
        print("epoch %i, loss_tr=%f err_tr=%f" % (epoch, loss_tot, err_tot))