コード例 #1
0
def _ratio(num, den):
    """Return ``num / den`` as a float, or 0.0 when ``den`` is zero (empty loader)."""
    return float(num) / den if den else 0.0


def _run_train_epoch(net, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` on the GPU.

    Returns:
        ``(loss_sum, correct, total)`` where ``loss_sum`` is the per-sample
        summed loss, ``correct`` the number of correct arg-max predictions and
        ``total`` the number of samples seen.
    """
    net.train()
    iter_size = len(loader)
    loss_sum = 0.0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(loader):
        # Single-line progress indicator, overwritten in place with '\r'.
        sys.stdout.write('iter : {} / {}\r'.format(i, iter_size))
        sys.stdout.flush()
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # CrossEntropyLoss() averages over the batch by default, so scale back
        # to a per-sample sum before dividing by the total sample count.
        # NOTE: .item() requires PyTorch >= 0.4 (loss.data[0] fails there).
        loss_sum += loss.item() * batch_size
        pred = torch.max(outputs.detach(), 1)[1]
        correct += (pred == labels).sum().item()
        total += batch_size
    # Blank out the progress line.
    sys.stdout.write(' ' * 20 + '\r')
    sys.stdout.flush()
    return loss_sum, correct, total


def _run_val_epoch(net, loader, criterion):
    """Evaluate ``net`` on ``loader`` without tracking gradients.

    Returns:
        ``(loss_sum, correct, total)`` with the same meaning as in
        ``_run_train_epoch``.
    """
    net.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    # no_grad: validation never calls backward(), so avoid building graphs.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            pred = torch.max(outputs, 1)[1]
            correct += (pred == labels).sum().item()
            total += batch_size
    return loss_sum, correct, total


def train(train_loader, val_loader, epochnum, save_path='.', save_freq=None):
    """Train a freshly built ``Encoder`` classifier, validating after each epoch.

    The network is moved to the GPU and optimised with SGD (lr=0.01,
    momentum=0.9, weight_decay=2e-4) against cross-entropy loss. After every
    epoch the learning rate of each parameter group is multiplied by
    ``exp(-0.4)``. Per epoch, the mean per-sample loss and the accuracy are
    printed for both the training and validation sets.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches used for
            training; ``len(train_loader)`` drives the progress display.
        val_loader: iterable of ``(inputs, labels)`` batches evaluated after
            every epoch.
        epochnum: number of epochs to run.
        save_path: directory where model snapshots are written.
        save_freq: if truthy, the whole module is saved as ``epoch_<n>``
            every ``save_freq`` epochs (after epochs ``save_freq - 1``,
            ``2 * save_freq - 1``, ...).

    Side effects:
        Writes the pickled module to ``<save_path>/trained_net`` when
        training finishes, plus the periodic snapshots described above.
    """
    net = Encoder()
    net.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(net.parameters(),
                          lr=0.01,
                          momentum=0.9,
                          weight_decay=2e-4)
    lr_decay = np.exp(-0.4)

    for epoch in range(epochnum):
        print('epoch : {}'.format(epoch))

        train_loss, train_correct, total = _run_train_epoch(
            net, train_loader, criterion, optimizer)
        print('train_loss:{}, train_acc:{:.2%}'.format(
            _ratio(train_loss, total), _ratio(train_correct, total)))

        val_loss, val_correct, total = _run_val_epoch(
            net, val_loader, criterion)
        print('val_loss:{}, val_acc:{:.2%}'.format(
            _ratio(val_loss, total), _ratio(val_correct, total)))

        # Exponential learning-rate decay, applied to every parameter group.
        for group in optimizer.param_groups:
            group['lr'] *= lr_decay

        if save_freq and epoch % save_freq == save_freq - 1:
            net_name = os.path.join(save_path, 'epoch_{}'.format(epoch))
            torch.save(net, net_name)
    torch.save(net, os.path.join(save_path, 'trained_net'))
コード例 #2
0
def load_encoder(obs_space, args, freeze=True):
    """Construct an ``Encoder`` and restore its weights from a checkpoint.

    The checkpoint file at ``args.dynamics_module`` is loaded with storages
    left where they were deserialised (i.e. on the CPU), and its ``'enc'``
    entry is used as the encoder's state dict. The encoder is put in eval
    mode before being returned.

    Args:
        obs_space: passed as the first positional argument to ``Encoder``.
        args: namespace providing ``dim``, ``use_conv`` and
            ``dynamics_module`` (path of the checkpoint file).
        freeze: when true, ``requires_grad`` is disabled on every encoder
            parameter.

    Returns:
        The restored encoder.
    """
    def keep_storage(storage, _location):
        # Hand the storage back untouched instead of relocating it to the
        # device it was saved from.
        return storage

    encoder = Encoder(obs_space, args.dim, use_conv=args.use_conv)
    checkpoint = torch.load(args.dynamics_module, map_location=keep_storage)
    encoder.load_state_dict(checkpoint['enc'])
    encoder.eval()
    if not freeze:
        return encoder
    for param in encoder.parameters():
        param.requires_grad = False
    return encoder
コード例 #3
0
def main(test_img_path):
    """Evaluate one image with every checkpoint given on the command line.

    Command-line options come from ``parse_args()``. For each entry of
    ``options.checkpoint``, an encoder/decoder pair is built from that
    checkpoint, or from ``default_checkpoint`` when the entry is empty/None.
    Both models are put in eval mode, and the result of ``evaluate`` is
    printed.

    Args:
        test_img_path: path of the image to evaluate. It is loaded once and
            converted to RGB.
    """
    options = parse_args()
    is_cuda = use_cuda and not options.no_cuda
    device = torch.device("cuda" if is_cuda else "cpu")

    # The image does not depend on the checkpoint, so load it once. The
    # context manager closes the file handle; convert() returns an
    # independent in-memory copy.
    with Image.open(test_img_path) as img:
        test_img = img.convert("RGB")

    for checkpoint_path in options.checkpoint:
        # A falsy entry (empty string / None) falls back to the default
        # checkpoint. No path manipulation is done before this check, so
        # None no longer raises.
        checkpoint = (load_checkpoint(checkpoint_path, cuda=is_cuda)
                      if checkpoint_path else default_checkpoint)
        encoder_checkpoint = checkpoint["model"].get("encoder")
        decoder_checkpoint = checkpoint["model"].get("decoder")

        enc = Encoder(img_channels=3, checkpoint=encoder_checkpoint).to(device)
        dec = Decoder(
            1,
            low_res_shape,
            high_res_shape,
            checkpoint=decoder_checkpoint,
            device=device,
        ).to(device)
        enc.eval()
        dec.eval()

        result = evaluate(
            enc,
            dec,
            test_img=test_img,
            device=device,
            checkpoint=checkpoint,
            beam_width=options.beam_width,
            prefix=options.prefix,
        )
        print(result)