示例#1
0
                   cell_size=args.mem_size,
                   sparse_reads=args.sparse_reads,
                   temporal_reads=args.temporal_reads,
                   read_heads=args.read_heads,
                   gpu_id=args.cuda,
                   debug=args.visdom,
                   batch_first=True,
                   independent_linears=False)
    elif args.memory_type == 'sam':
        rnn = SAM(input_size=args.input_size,
                  hidden_size=args.nhid,
                  rnn_type=args.rnn_type,
                  num_layers=args.nlayer,
                  num_hidden_layers=args.nhlayer,
                  dropout=args.dropout,
                  nr_cells=args.mem_slot,
                  cell_size=args.mem_size,
                  sparse_reads=args.sparse_reads,
                  read_heads=args.read_heads,
                  gpu_id=args.cuda,
                  debug=args.visdom,
                  batch_first=True,
                  independent_linears=False)
    else:
        raise Exception('Not recognized type of memory')

    if args.cuda != -1:
        rnn = rnn.cuda(args.cuda)

    print(rnn)

    last_save_losses = []
示例#2
0
文件: train.py 项目: zoharli/armin
          controller_size=args.lstm_size,
          memory_units=128,
          memory_unit_size=20,
          num_heads=1)#task_params['num_heads'])
elif args.model=='dnc':
    model = DNC(input_size= input_size,
          output_size=output_size,
          hidden_size=args.lstm_size,
          nr_cells=128,
          cell_size=20,
          read_heads=1)#task_params['num_heads'])
    model.init_param()
elif args.model=='sam':
    model = SAM(input_size= input_size,
          output_size=output_size,
          hidden_size=args.lstm_size,
          nr_cells=128,
          cell_size=20,
          read_heads=1)#read_heads=4???#task_params['num_heads'])
    model.init_param()
elif args.model=='lstm':
    marnn_config=args
    print('marnn_config:\n',marnn_config)
    model = MARNN(marnn_config,input_size=input_size,
            num_units=marnn_config.lstm_size,
            output_size=output_size,
            use_zoneout=False,
            use_ln=False)
else:
    has_tau=1
    marnn_config=args
    print('marnn_config:\n',marnn_config)