Beispiel #1
0
# Hyper-parameters and MDRNN model/optimizer setup, optionally resuming from
# <args.logdir>/mdrnn/best.tar.
# BSIZE, SEQ_LEN and epochs are not read in this section; presumably they are
# consumed further down the script -- TODO confirm.
BSIZE = 8
SEQ_LEN = 999
epochs = 3000
# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True
learning_rate = 1e-4
# Loading model

rnn_dir = join(args.logdir, 'mdrnn')
rnn_file = join(rnn_dir, 'best.tar')

# makedirs also creates a missing args.logdir (a plain mkdir would raise), and
# exist_ok=True replaces the racy exists()-then-mkdir() check.
from os import makedirs  # local import: the file-level import block is not visible here
makedirs(rnn_dir, exist_ok=True)

mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
# device_ids skips GPU 0. DataParallel requires the wrapped module's
# parameters to live on device_ids[0], hence the .cuda(1) below.
mdrnn = torch.nn.DataParallel(mdrnn, device_ids=[1, 2, 3, 4, 5, 6, 7])
mdrnn.cuda(1)
#mdrnn.to(device)
# Use the learning_rate constant rather than a duplicated literal.
optimizer = optim.Adam(mdrnn.parameters(), lr=learning_rate, betas=(0.9, 0.999))
# scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
# earlystopping = EarlyStopping('min', patience=30)

if exists(rnn_file) and not args.noreload:
    # map_location='cpu' keeps torch.load from restoring tensors onto the GPU
    # they were saved from (e.g. cuda:0, which this run excludes). The
    # load_state_dict calls below copy/cast them onto the parameters' device.
    rnn_state = torch.load(rnn_file, map_location='cpu')
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    # NOTE(review): mdrnn is a DataParallel wrapper, so its state-dict keys
    # carry a "module." prefix. This only matches checkpoints saved from a
    # wrapped model -- confirm how best.tar is written.
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # scheduler.load_state_dict(state['scheduler'])
    # earlystopping.load_state_dict(state['earlystopping'])
Beispiel #2
0
# Hyper-parameters and MDRNN model/optimizer setup, optionally resuming from
# <args.logdir>/mdrnn/best.tar.
# BSIZE, SEQ_LEN and epochs are not read in this section; presumably they are
# consumed further down the script -- TODO confirm.
BSIZE = 8
SEQ_LEN = 999
epochs = 3000
# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True
learning_rate = 1e-4
# Loading model

rnn_dir = join(args.logdir, 'mdrnn')
rnn_file = join(rnn_dir, 'best.tar')

# makedirs also creates a missing args.logdir (a plain mkdir would raise), and
# exist_ok=True replaces the racy exists()-then-mkdir() check.
from os import makedirs  # local import: the file-level import block is not visible here
makedirs(rnn_dir, exist_ok=True)

mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
# Use up to 8 GPUs, but never more than are present. A hard-coded range(8)
# makes DataParallel reject the device ids on machines with fewer GPUs.
num_gpus = min(8, torch.cuda.device_count())
mdrnn = torch.nn.DataParallel(mdrnn, device_ids=list(range(num_gpus)))
# Module must live on device_ids[0] (= cuda:0), the default for .cuda().
mdrnn.cuda()
#mdrnn.to(device)
# Use the learning_rate constant rather than a duplicated literal.
optimizer = optim.Adam(mdrnn.parameters(), lr=learning_rate, betas=(0.9, 0.999))
# scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
# earlystopping = EarlyStopping('min', patience=30)

if exists(rnn_file) and not args.noreload:
    # map_location='cpu' keeps torch.load from restoring tensors onto whatever
    # GPU they were saved from. The load_state_dict calls below copy/cast them
    # onto the parameters' device.
    rnn_state = torch.load(rnn_file, map_location='cpu')
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    # NOTE(review): mdrnn is a DataParallel wrapper, so its state-dict keys
    # carry a "module." prefix. This only matches checkpoints saved from a
    # wrapped model -- confirm how best.tar is written.
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # scheduler.load_state_dict(state['scheduler'])
    # earlystopping.load_state_dict(state['earlystopping'])