Exemplo n.º 1
0
                          mode='encoder')

            if args.num_train_dec > 0:
                for idx in range(args.num_train_dec):
                    train(epoch,
                          model,
                          dec_optimizer,
                          args,
                          use_cuda=use_cuda,
                          mode='decoder')

        this_loss, this_ber = validate(model,
                                       general_optimizer,
                                       args,
                                       use_cuda=use_cuda)
        report_loss.append(this_loss)
        report_ber.append(this_ber)

    if args.print_test_traj == True:
        print('test loss trajectory', report_loss)
        print('test ber trajectory', report_ber)
        print('total epoch', args.num_epoch)

    #################################################
    # Testing Processes
    #################################################
    test(model, args, use_cuda=use_cuda)

    torch.save(model.state_dict(), './tmp/torch_model_' + identity + '.pt')
    print('saved model', './tmp/torch_model_' + identity + '.pt')
Exemplo n.º 2
0
        report_loss.append(this_loss)
        report_ber.append(this_ber)
        report_bler.append(this_bler)

        data_file = open(filename, 'a')
        data_file.write(
            str(epoch) + ' ' + str(this_loss) + ' ' + str(this_ber) + ' ' +
            str(this_bler) + "\n")
        data_file.close()

        # save model per epoch
        modelpath = './tmp/attention_model_' + str(epoch) + '_' + str(
            args.channel) + '_lr_' + str(args.enc_lr) + '_D' + str(
                args.D) + '_' + str(args.num_block) + '_' + timestamp + '.pt'
        # modelpath = './tmp/model_'+str(epoch)+'_'+str(args.channel)+'_lr_'+str(args.enc_lr)+'_D'+str(args.D)+'_'+str(args.num_block)+'.pt'
        torch.save(model.state_dict(), modelpath)
        # try:
        #     # pre_modelpath = './tmp/model_'+str(epoch-1)+'_'+str(args.channel)+'_lr_'+str(args.enc_lr)+'_D'+str(args.D)+'_'+str(args.num_block)+'_'+timestamp+'.pt'
        #     pre_modelpath = './tmp/model_'+str(epoch-1)+'_'+str(args.channel)+'_lr_'+str(args.enc_lr)+'_D'+str(args.D)+'_'+str(args.num_block)+'.pt'
        #     os.system(r"rm -f {}".format(pre_modelpath))#调用系统命令行来删除文件
        # except:
        #     pass
        print('saved model', modelpath)
        print("each epoch training time: {}s".format(time.time() -
                                                     epoch_start_time))

    if args.print_test_traj == True:
        print('test loss trajectory', report_loss)
        print('test ber trajectory', report_ber)
        print('test bler trajectory', report_bler)
        print('total epoch', args.num_epoch)