Exemplo n.º 1
0
    # filename = './data/data_'+str(args.channel)+'_lr_'+str(args.enc_lr)+'_D'+str(args.D)+'_'+str(args.num_block)+'.txt'

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("use_cuda: ", use_cuda)
    device = torch.device("cuda" if use_cuda else "cpu")

    #################################################
    # Setup Channel AE: Encoder, Decoder, Channel
    #################################################

    encoder = ENC(args)
    decoder = DEC(args)

    # choose support channels
    from channel_ae import Channel_AE
    model = Channel_AE(args, encoder, decoder).to(device)

    # weight loading
    if args.init_nw_weight == 'default':
        pass

    else:
        pretrained_model = torch.load(args.init_nw_weight)

        try:
            model.load_state_dict(pretrained_model.state_dict(), strict=False)

        except:
            model.load_state_dict(pretrained_model, strict=False)

        model.args = args
Exemplo n.º 2
0
    elif args.is_interleave == 0:
        p_array = range(args.block_len)  # no interleaver.
    else:
        seed = np.random.randint(0, args.is_interleave)
        rand_gen = mtrand.RandomState(seed)
        p_array = rand_gen.permutation(arange(args.block_len))

        print('using random interleaver', p_array)

    encoder = ENC(args, p_array)
    decoder = DEC(args, p_array)

    # choose support channels
    from channel_ae import Channel_AE
    model = Channel_AE(args, encoder, decoder).to(device)

    # make the model parallel
    if args.is_parallel == 1:
        model.enc.set_parallel()
        model.dec.set_parallel()

    # weight loading
    if args.init_nw_weight == 'default':
        pass

    else:
        pretrained_model = torch.load(args.init_nw_weight)

        try:
            model.load_state_dict(pretrained_model.state_dict(), strict=False)
Exemplo n.º 3
0
        seed = np.random.randint(0, args.is_interleave)
        rand_gen = mtrand.RandomState(seed)
        p_array2 = rand_gen.permutation(arange(args.block_len))

    print('using random interleaver', p_array1, p_array2)

    if args.encoder == 'turboae_2int' and args.decoder == 'turboae_2int':
        encoder = ENC(args, p_array1, p_array2)
        decoder = DEC(args, p_array1, p_array2)
    else:
        encoder = ENC(args, p_array1)
        decoder = DEC(args, p_array1)

    # choose support channels
    from channel_ae import Channel_AE
    model = Channel_AE(args, encoder, decoder).to(device)
    # from channel_ae import Channel_ModAE
    # model = Channel_ModAE(args, encoder, decoder).to(device)


    # make the model parallel
    if args.is_parallel == 1:
        model.enc.set_parallel()
        model.dec.set_parallel()

    # weight loading
    if args.init_nw_weight == 'default':
        pass

    else:
        pretrained_model = torch.load(args.init_nw_weight,map_location=torch.device('cpu'))