Ejemplo n.º 1
0
class ET_Net(nn.Module):
    """Edge-attention guidance network (ET-Net) for medical image segmentation.

    The network is assembled from four sub-modules: an encoder, a decoder,
    an edge guidance module (EGM) and a weighted aggregation module (WAM).
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in this order: encoder, decoder, EGM, WAM.
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.egm = EdgeGuidanceModule()
        self.wam = WeightedAggregationModule()

    def forward(self, x):
        """Run the whole network on ``x``.

        Returns:
            A ``(edge_pred, pred)`` tuple: the first output of the EGM and
            the output of the WAM.
        """
        # The encoder must yield exactly four outputs, the decoder three.
        e1, e2, e3, e4 = self.encoder(x)
        d1, d2, d3 = self.decoder(e1, e2, e3, e4)

        # Only the first two encoder outputs feed the edge guidance module;
        # its second output is forwarded to the aggregation module.
        edge_pred, edge_feat = self.egm(e1, e2)
        segmentation = self.wam(d1, d2, d3, edge_feat)
        return edge_pred, segmentation

    def load_encoder_weight(self):
        """Load the encoder's state dict from ``ARGS['encoder_weight']``."""
        # One could get the pretrained weights via PyTorch official.
        weight_file = ARGS['encoder_weight']
        state = torch.load(weight_file)
        self.encoder.load_state_dict(state)
Ejemplo n.º 2
0
def run_model(mode, path, in_file, o_file):
    """Build the models for the configured model type, then train or predict.

    NOTE: this is Python 2 code (``print`` statements are used throughout).

    Args:
        mode: 'train' or 'test'; also stored on the configuration as
            ``cfg.mode``.
        path: Prefix for checkpoint files. It is joined to the file names by
            plain string concatenation, so it must carry its own trailing
            separator.
        in_file: Assigned to ``cfg.test_raw`` when ``mode`` is 'test'; not
            used otherwise.
        o_file: Passed to ``predict`` in 'test' mode. In 'train' mode the
            argument is ignored and replaced by
            './temp.predicted_<model_type>'.

    Side effects:
        Rebinds module-level globals named in the ``global`` statement below
        (which ones depends on ``cfg.model_type``), seeds the ``random``,
        NumPy and torch generators from ``cfg.seed``, and loads/saves
        checkpoint files whose names start with ``path``.
    """
    global feature, encoder, indp, crf, mldecoder, rltrain, f_opt, e_opt, i_opt, c_opt, m_opt, r_opt

    cfg = Configuration()

    #General mode has two values: 'train' or 'test'
    cfg.mode = mode

    #Set Random Seeds
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if hasCuda:
        torch.cuda.manual_seed_all(cfg.seed)

    #Load Embeddings
    load_embeddings(cfg)

    #Only for testing
    if mode == 'test': cfg.test_raw = in_file

    #Construct models
    # Feature and encoder are built for every model type. For 'AC-RNN' they
    # are optimized with SGD at cfg.actor_step_size; every other type uses
    # Adam at cfg.learning_rate. Only parameters with requires_grad set are
    # handed to the optimizers (ifilter is presumably itertools.ifilter).
    feature = Feature(cfg)
    if cfg.model_type == 'AC-RNN':
        f_opt = optim.SGD(ifilter(lambda p: p.requires_grad,
                                  feature.parameters()),
                          lr=cfg.actor_step_size)
    else:
        f_opt = optim.Adam(ifilter(lambda p: p.requires_grad,
                                   feature.parameters()),
                           lr=cfg.learning_rate)

    if hasCuda: feature.cuda()

    encoder = Encoder(cfg)
    if cfg.model_type == 'AC-RNN':
        e_opt = optim.SGD(ifilter(lambda p: p.requires_grad,
                                  encoder.parameters()),
                          lr=cfg.actor_step_size)
    else:
        e_opt = optim.Adam(ifilter(lambda p: p.requires_grad,
                                   encoder.parameters()),
                           lr=cfg.learning_rate)
    if hasCuda: encoder.cuda()

    # Build the predictor specific to the model type. If cfg.model_type
    # matches none of the branches below, only feature and encoder exist.
    if cfg.model_type == 'INDP':
        indp = INDP(cfg)
        i_opt = optim.Adam(ifilter(lambda p: p.requires_grad,
                                   indp.parameters()),
                           lr=cfg.learning_rate)
        if hasCuda: indp.cuda()

    elif cfg.model_type == 'CRF':
        crf = CRF(cfg)
        c_opt = optim.Adam(ifilter(lambda p: p.requires_grad,
                                   crf.parameters()),
                           lr=cfg.learning_rate)
        if hasCuda: crf.cuda()

    elif cfg.model_type == 'TF-RNN':
        mldecoder = MLDecoder(cfg)
        m_opt = optim.Adam(ifilter(lambda p: p.requires_grad,
                                   mldecoder.parameters()),
                           lr=cfg.learning_rate)
        if hasCuda: mldecoder.cuda()
        cfg.mldecoder_type = 'TF'

    elif cfg.model_type == 'SS-RNN':
        mldecoder = MLDecoder(cfg)
        m_opt = optim.Adam(ifilter(lambda p: p.requires_grad,
                                   mldecoder.parameters()),
                           lr=cfg.learning_rate)
        if hasCuda: mldecoder.cuda()
        cfg.mldecoder_type = 'SS'

    elif cfg.model_type == 'AC-RNN':
        mldecoder = MLDecoder(cfg)
        m_opt = optim.SGD(ifilter(lambda p: p.requires_grad,
                                  mldecoder.parameters()),
                          lr=cfg.actor_step_size)
        if hasCuda: mldecoder.cuda()
        cfg.mldecoder_type = 'TF'
        rltrain = RLTrain(cfg)
        r_opt = optim.Adam(ifilter(lambda p: p.requires_grad,
                                   rltrain.parameters()),
                           lr=cfg.learning_rate,
                           weight_decay=0.001)
        if hasCuda: rltrain.cuda()
        cfg.rltrain_type = 'AC'
        #For RL, the network should be pre-trained with teacher forced ML decoder.
        # NOTE: these TF-RNN checkpoints are loaded in both 'train' and
        # 'test' mode, so the files must exist either way; in 'test' mode the
        # weights are replaced again below by the AC-RNN checkpoints.
        feature.load_state_dict(torch.load(path + 'TF-RNN' + '_feature'))
        encoder.load_state_dict(torch.load(path + 'TF-RNN' + '_encoder'))
        mldecoder.load_state_dict(torch.load(path + 'TF-RNN' + '_predictor'))

    if mode == 'train':
        # The o_file argument is discarded here: validation predictions go to
        # a per-model temporary file instead.
        o_file = './temp.predicted_' + cfg.model_type
        best_val_cost = float('inf')
        best_val_epoch = 0
        first_start = time.time()
        epoch = 0
        while (epoch < cfg.max_epochs):
            print
            print 'Model:{} | Epoch:{}'.format(cfg.model_type, epoch)

            if cfg.model_type == 'SS-RNN':
                #Specify the decaying schedule for sampling probability.
                #inverse sigmoid schedule:
                cfg.sampling_p = float(
                    cfg.k) / float(cfg.k + np.exp(float(epoch) / cfg.k))

            start = time.time()
            run_epoch(cfg)
            print '\nValidation:'
            predict(cfg, o_file)
            # Lower cost is better. evaluate() presumably returns a score on
            # a 0-100 scale for o_file against cfg.dev_ref -- TODO confirm.
            val_cost = 100 - evaluate(cfg, cfg.dev_ref, o_file)
            print 'Validation score:{}'.format(100 - val_cost)
            if val_cost < best_val_cost:
                # New best validation cost: checkpoint every component of the
                # current model under '<path><model_type>_<component>'.
                best_val_cost = val_cost
                best_val_epoch = epoch
                torch.save(feature.state_dict(),
                           path + cfg.model_type + '_feature')
                torch.save(encoder.state_dict(),
                           path + cfg.model_type + '_encoder')
                if cfg.model_type == 'INDP':
                    torch.save(indp.state_dict(),
                               path + cfg.model_type + '_predictor')
                elif cfg.model_type == 'CRF':
                    torch.save(crf.state_dict(),
                               path + cfg.model_type + '_predictor')
                elif cfg.model_type == 'TF-RNN' or cfg.model_type == 'SS-RNN':
                    torch.save(mldecoder.state_dict(),
                               path + cfg.model_type + '_predictor')
                elif cfg.model_type == 'AC-RNN':
                    torch.save(mldecoder.state_dict(),
                               path + cfg.model_type + '_predictor')
                    torch.save(rltrain.state_dict(),
                               path + cfg.model_type + '_critic')

            #For early stopping
            # Stop once more than cfg.early_stopping epochs have passed since
            # the best epoch. NOTE: breaking here skips the per-epoch timing
            # print below for the final epoch.
            if epoch - best_val_epoch > cfg.early_stopping:
                break
                ###

            print 'Epoch training time:{} seconds'.format(time.time() - start)
            epoch += 1

        print 'Total training time:{} seconds'.format(time.time() -
                                                      first_start)

    elif mode == 'test':
        # Prediction overrides the configured batch size with a fixed value.
        cfg.batch_size = 256
        # Restore the checkpoints written by a previous 'train' run of the
        # same model type.
        feature.load_state_dict(torch.load(path + cfg.model_type + '_feature'))
        encoder.load_state_dict(torch.load(path + cfg.model_type + '_encoder'))
        if cfg.model_type == 'INDP':
            indp.load_state_dict(
                torch.load(path + cfg.model_type + '_predictor'))
        elif cfg.model_type == 'CRF':
            crf.load_state_dict(
                torch.load(path + cfg.model_type + '_predictor'))
        elif cfg.model_type == 'TF-RNN' or cfg.model_type == 'SS-RNN':
            mldecoder.load_state_dict(
                torch.load(path + cfg.model_type + '_predictor'))
        elif cfg.model_type == 'AC-RNN':
            mldecoder.load_state_dict(
                torch.load(path + cfg.model_type + '_predictor'))
            rltrain.load_state_dict(
                torch.load(path + cfg.model_type + '_critic'))

        print
        print 'Model:{} Predicting'.format(cfg.model_type)
        start = time.time()
        predict(cfg, o_file)
        print 'Total prediction time:{} seconds'.format(time.time() - start)
    return
Ejemplo n.º 3
0
# Pull the run settings out of the config mapping.
channels = config["channel"]
alpha = config["alpha"]
csv_path = config["csv_path"]
img_dir = config["image_dir"]
output_dir = config["output_dir"]

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it also creates
# missing parent directories and does not fail if the directory appears
# between the check and the creation.
os.makedirs(output_dir, exist_ok=True)

# The CSV lists one image per row: "ImageId" is the file name relative to
# img_dir and "TrueLabel" is its label.
data = pd.read_csv(csv_path)
paths = data["ImageId"].values
paths = [os.path.join(img_dir, p) for p in paths]
labels = data["TrueLabel"].values

# Build both networks; the encoder's output width (2048) is the decoder's
# input width.
encoder = Encoder(channels, out_ch=2048)
decoder = Decoder(2048, channels)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoints are always deserialized onto the CPU first; each module is then
# moved to the selected device and switched to evaluation mode.
encoder.load_state_dict(torch.load(config["encoder"], map_location="cpu"))
encoder.to(device)
encoder.eval()

decoder.load_state_dict(torch.load(config["decoder"], map_location="cpu"))
decoder.to(device)
decoder.eval()

x_adv = []
# Gradients are not needed while iterating over the images.
with torch.no_grad():
    bar = tqdm.tqdm(paths)
    for path in bar:
        # Show the file currently being processed in the progress bar.
        filename = os.path.basename(path)
        bar.set_description(f"processing:{filename}")
        image = cv2.imread(path)
Ejemplo n.º 4
0
def inference(checkpoint_file, text):
    """Synthesize ``text`` with a trained model and write it to "demo.wav".

    Args:
        checkpoint_file: Checkpoint accepted by ``torch.load``. The loaded
            object must provide the state dicts under the keys 'encoder',
            'decoder' and 'postnet'.
        text: Text to synthesize; it is converted to indexes with
            ``indexes_from_text(ds.lang, text)``.

    Returns:
        None. The reconstructed waveform is written to "demo.wav" with
        sample rate ``hp.sr``.
    """
    # The dataset is built only to obtain its language object (ds.lang),
    # which is needed to index the text and to size the encoder.
    ds = tiny_words(max_text_length=hp.max_text_length,
                    max_audio_length=hp.max_audio_length,
                    max_dataset_size=args.data_size)

    print(ds.texts)

    # prepare input: index the text, terminate it with EOT, pad it to the
    # maximum text length and add a leading batch dimension of size 1.
    indexes = indexes_from_text(ds.lang, text)
    indexes.append(EOT_token)
    padded_indexes = pad_indexes(indexes, hp.max_text_length, PAD_token)
    texts_v = Variable(torch.from_numpy(padded_indexes))
    texts_v = texts_v.unsqueeze(0)

    if hp.use_cuda:
        texts_v = texts_v.cuda()

    encoder = Encoder(ds.lang.num_chars,
                      hp.embedding_dim,
                      hp.encoder_bank_k,
                      hp.encoder_bank_ck,
                      hp.encoder_proj_dims,
                      hp.encoder_highway_layers,
                      hp.encoder_highway_units,
                      hp.encoder_gru_units,
                      dropout=hp.dropout,
                      use_cuda=hp.use_cuda)

    decoder = AttnDecoder(hp.max_text_length,
                          hp.attn_gru_hidden_size,
                          hp.n_mels,
                          hp.rf,
                          hp.decoder_gru_hidden_size,
                          hp.decoder_gru_layers,
                          dropout=hp.dropout,
                          use_cuda=hp.use_cuda)

    postnet = PostNet(hp.n_mels,
                      1 + hp.n_fft // 2,
                      hp.post_bank_k,
                      hp.post_bank_ck,
                      hp.post_proj_dims,
                      hp.post_highway_layers,
                      hp.post_highway_units,
                      hp.post_gru_units,
                      use_cuda=hp.use_cuda)

    # Evaluation mode for all three networks (e.g. disables dropout).
    encoder.eval()
    decoder.eval()
    postnet.eval()

    if hp.use_cuda:
        encoder.cuda()
        decoder.cuda()
        postnet.cuda()

    # load model
    # Without CUDA, remap every stored tensor to the CPU so that a checkpoint
    # saved from GPU tensors can still be deserialized on a CPU-only machine.
    # The callable form of map_location is used because it is accepted by old
    # and new torch releases alike.
    if hp.use_cuda:
        checkpoint = torch.load(checkpoint_file)
    else:
        checkpoint = torch.load(checkpoint_file,
                                map_location=lambda storage, loc: storage)
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    postnet.load_state_dict(checkpoint['postnet'])

    encoder_out = encoder(texts_v)

    # Prepare input and output variables
    # The first decoder input is an all-zero "GO" frame of n_mels values.
    GO_frame = np.zeros((1, hp.n_mels))
    decoder_in = Variable(torch.from_numpy(GO_frame).float())
    if hp.use_cuda:
        decoder_in = decoder_in.cuda()
    h, hs = decoder.init_hiddens(1)

    # Decode step by step for max_audio_length / rf steps, feeding the last
    # frame of each step's output back in as the next input.
    decoder_outs = []
    for _ in range(int(hp.max_audio_length / hp.rf)):
        decoder_out, h, hs, _ = decoder(decoder_in, h, hs, encoder_out)
        decoder_outs.append(decoder_out)
        # use predict
        decoder_in = decoder_out[:, -1, :].contiguous()

    # (batch_size, T, n_mels)
    decoder_outs = torch.cat(decoder_outs, 1)

    # postnet
    post_out = postnet(decoder_outs)
    s = post_out[0].cpu().data.numpy()

    print("Recontructing wav...")
    # Clamp negative values to zero before applying the power.
    s = np.where(s < 0, 0, s)
    wav = spectrogram2wav(s**hp.power)
    # wav = griffinlim(s**hp.power)
    write("demo.wav", hp.sr, wav)