def main(args):
    (train_doc, train_summ, test_doc, test_summ,
     val_doc, val_summ) = read_from_file(args.data_file)
    vocab = read_vocab(args.path_to_vocab)

    embeddings = Embeddings(args.embed_size, args.vocab_size).cuda()
    encoder = Encoder(args.embed_size, args.hidden_size).cuda()
    decoder = Decoder(args.embed_size, args.hidden_size,
                      args.vocab_size).cuda()
    generator = BeamSearch(vocab, args.max_decode_len, args.min_decode_len,
                           args.beam_size)

    trainloader = DataLoader(train_doc, train_summ, vocab, args.batch_size,
                             args.max_doc_len, args.max_summ_len)
    testloader = Test_loader(test_doc, test_summ, vocab, args.max_doc_len,
                             args.max_summ_test_len)
    valloader = Test_loader(val_doc, val_summ, vocab, args.max_doc_len,
                            args.max_summ_test_len)

    if args.use_pretrained:
        params = torch.load(args.pretrained_model)
        embeddings.load_state_dict(params['embed_params'])
        encoder.load_state_dict(params['encoder_params'])
        decoder.load_state_dict(params['decoder_params'])

        test(embeddings, encoder, decoder, testloader, generator, args.lambda_)

    train(embeddings, encoder, decoder, generator, trainloader, valloader,
          args.iterations, args.lambda_, args.lr, args.max_grad_norm,
          args.initial_accum_val, args.threshold)

    test(embeddings, encoder, decoder, testloader, generator, args.lambda_)
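
A minimal argparse wiring for this entry point might look as follows; the flag names are inferred from the attributes that main reads off args, while every default value is an illustrative assumption:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train/test the summarization model')
    parser.add_argument('--data_file', type=str, required=True)
    parser.add_argument('--path_to_vocab', type=str, required=True)
    parser.add_argument('--embed_size', type=int, default=128)
    parser.add_argument('--hidden_size', type=int, default=256)
    parser.add_argument('--vocab_size', type=int, default=50000)
    parser.add_argument('--max_decode_len', type=int, default=100)
    parser.add_argument('--min_decode_len', type=int, default=35)
    parser.add_argument('--beam_size', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--max_doc_len', type=int, default=400)
    parser.add_argument('--max_summ_len', type=int, default=100)
    parser.add_argument('--max_summ_test_len', type=int, default=120)
    parser.add_argument('--use_pretrained', action='store_true')
    parser.add_argument('--pretrained_model', type=str, default='')
    parser.add_argument('--lambda_', type=float, default=1.0)
    parser.add_argument('--iterations', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--max_grad_norm', type=float, default=2.0)
    parser.add_argument('--initial_accum_val', type=float, default=0.1)
    parser.add_argument('--threshold', type=float, default=0.5)
    main(parser.parse_args())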
Example #2
def main():
    parser = argparse.ArgumentParser(description='Test learned model')
    parser.add_argument('dir',
                        type=str,
                        help='log directory to load learned model')
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--domain-name', type=str, default='cheetah')
    parser.add_argument('--task-name', type=str, default='run')
    parser.add_argument('-R', '--action-repeat', type=int, default=2)
    parser.add_argument('--episodes', type=int, default=1)
    args = parser.parse_args()

    # define environment and apply wrapper
    env = suite.load(args.domain_name, args.task_name)
    env = pixels.Wrapper(env,
                         render_kwargs={
                             'height': 64,
                             'width': 64,
                             'camera_id': 0
                         })
    env = GymWrapper(env)
    env = RepeatAction(env, skip=args.action_repeat)

    # define models
    with open(os.path.join(args.dir, 'args.json'), 'r') as f:
        train_args = json.load(f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = Encoder().to(device)
    rssm = RecurrentStateSpaceModel(train_args['state_dim'],
                                    env.action_space.shape[0],
                                    train_args['rnn_hidden_dim']).to(device)
    action_model = ActionModel(train_args['state_dim'],
                               train_args['rnn_hidden_dim'],
                               env.action_space.shape[0]).to(device)

    # load learned parameters
    encoder.load_state_dict(torch.load(os.path.join(args.dir, 'encoder.pth')))
    rssm.load_state_dict(torch.load(os.path.join(args.dir, 'rssm.pth')))
    action_model.load_state_dict(
        torch.load(os.path.join(args.dir, 'action_model.pth')))

    # define agent
    policy = Agent(encoder, rssm, action_model)

    # test learned model in the environment
    for episode in range(args.episodes):
        policy.reset()
        obs = env.reset()
        done = False
        total_reward = 0
        while not done:
            action = policy(obs)
            obs, reward, done, _ = env.step(action)
            total_reward += reward
            if args.render:
                env.render(height=256, width=256, camera_id=0)

        print('Total test reward at episode [%4d/%4d] is %f' %
              (episode + 1, args.episodes, total_reward))
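
RepeatAction is imported from elsewhere in this repo and is not shown. A plausible sketch of such an action-repeat wrapper, assuming the classic gym API used above (4-tuple step), repeats each action skip times and accumulates the reward:

import gym

class RepeatAction(gym.Wrapper):
    def __init__(self, env, skip=2):
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        # Apply the same action `skip` times, summing rewards,
        # and stop early if the episode terminates.
        total_reward = 0.0
        done = False
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info
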
def get_encoder(latent_dim, fckpt='', ker_size=11):
    E = Encoder(z_dim=latent_dim, first_filter_size=ker_size)
    if fckpt and os.path.exists(fckpt):
        ckpt = torch.load(fckpt)
        loaded_sd = ckpt['E']
        try:
            E.load_state_dict(loaded_sd)
        except RuntimeError:  # key or shape mismatch; fall back to a partial load
            curr_params = E.state_dict()
            curr_keys = list(curr_params.keys())

            updated_params = {}
            for k, v in loaded_sd.items():
                if 'bn7' in k:
                    newk = k.replace('bn7', 'conv7')
                else:
                    newk = k
                if newk in curr_keys and v.shape == curr_params[newk].shape:
                    updated_params[newk] = v
                else:
                    print('Failed to load:', k)
            curr_params.update(updated_params)
            E.load_state_dict(curr_params)
    return E
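
The fallback branch above is a general partial-loading pattern: keep only checkpoint entries whose (possibly renamed) key exists in the model with a matching shape. The same idea as a standalone helper (a sketch with hypothetical names, not part of the original code):

import torch.nn as nn

def load_partial_state_dict(module: nn.Module, loaded_sd, rename=None):
    # Keep only parameters whose (possibly renamed) key exists in the
    # module and whose shape matches; report everything else.
    curr = module.state_dict()
    kept = {}
    for k, v in loaded_sd.items():
        newk = rename(k) if rename else k
        if newk in curr and curr[newk].shape == v.shape:
            kept[newk] = v
        else:
            print('Failed to load:', k)
    curr.update(kept)
    module.load_state_dict(curr)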
Example #4
def instantiate_model(config, tokenizer):
    configure_devices(config)
    model = Model(config)
    optimizer = transformers.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0)
    metrics = None

    if config.continue_training:
        state_dict = torch.load(config.continue_training, map_location='cpu')
        model.load_state_dict(state_dict['model'])
        if 'optimizer_state_dict' in state_dict:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            for g in optimizer.param_groups:
                g['lr'] = config.learning_rate
        
        try:
            print(f"Loaded model:\nEpochs: {state_dict['epoch']}\nLoss: {state_dict['loss']}\n", 
                  f"Recall: {state_dict['rec']}\nMRR: {state_dict['mrr']}")
        except KeyError:  # checkpoint may not store these metrics
            pass
        
    if config.use_cuda:
        model = model.cuda()
        optimizer_to(optimizer, config.device)
        model = torch.nn.DataParallel(model, device_ids=config.devices)
    return model, optimizer, metrics
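
optimizer_to is not part of torch; it is presumably a helper along the lines of the common community snippet below, which walks the optimizer state and moves every tensor onto the target device (a sketch under that assumption):

import torch

def optimizer_to(optim, device):
    for state in optim.state.values():       # per-parameter state dicts
        for k, v in state.items():
            if isinstance(v, torch.Tensor):  # e.g. Adam's exp_avg buffers
                state[k] = v.to(device)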
Example #5
class Model:
    def __init__(self, chpt_enc_path, chpt_dec_path, chpt_stat_path):
        historyLength = 10

        encoder_dim = hiddenDimension
        lstm_input_dim = historyLength + 1
        decoder_dim = hiddenDimension
        attention_dim = hiddenDimension
        output_dim = 1

        self.decodeLength = 20

        self.encoder = Encoder()
        self.decoder = DecoderWithAttention(encoder_dim, lstm_input_dim, decoder_dim, attention_dim, output_dim)

        self.encoder.load_state_dict(torch.load(chpt_enc_path))
        self.decoder.load_state_dict(torch.load(chpt_dec_path))

        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)

        self.encoder.eval()
        self.decoder.eval()

        with open(chpt_stat_path, 'rb') as f:
            chpt_stat = pickle.load(f)

        self.cMean = chpt_stat['cMean_tr']
        self.cStd = chpt_stat['cStd_tr']

        self.vMean = chpt_stat['vMean_tr']
        self.vStd = chpt_stat['vStd_tr']

        self.aMean = chpt_stat['aMean_tr']
        self.aStd = chpt_stat['aStd_tr']

        self.mean = torch.Tensor([self.vMean, self.aMean]).to(device)
        self.std = torch.Tensor([self.vStd, self.aStd]).to(device)

    def predict(self, curvatures, currentSpeed, histSpeeds, currentAccelX, histAccelXs):
        curvatures = torch.FloatTensor(curvatures).to(device)

        currentSpeed = torch.FloatTensor([currentSpeed]).to(device)
        histSpeeds = torch.FloatTensor(histSpeeds).to(device)

        currentAccelX = torch.FloatTensor([currentAccelX]).to(device)
        histAccelXs = torch.FloatTensor(histAccelXs).to(device)

        curvatures = (curvatures - self.cMean) / self.cStd
        currentSpeed = (currentSpeed - self.vMean) / self.vStd
        histSpeeds = (histSpeeds - self.vMean) / self.vStd
        currentAccelX = (currentAccelX - self.aMean) / self.aStd
        histAccelXs = (histAccelXs - self.aMean) / self.aStd

        curvatures = self.encoder(curvatures.unsqueeze(dim=0).unsqueeze(dim=0))
        predictions, alphas, alphas_target = self.decoder(
            curvatures, currentSpeed, histSpeeds.unsqueeze(dim=0),
            currentAccelX, histAccelXs.unsqueeze(dim=0),
            self.decodeLength, self.vMean, self.vStd, self.aMean, self.aStd)

        return ((predictions.squeeze() * self.aStd + self.aMean).cpu().detach().numpy(),
                alphas.squeeze().cpu().detach().numpy())
Example #6
def main(args):
    torch.multiprocessing.set_start_method('spawn')
    torch.distributed.init_process_group(backend="nccl")

    with open(args.config_path, 'r') as file:
        config = AttrDict(json.load(file))

    set_seed(config.seed + torch.distributed.get_rank())

    train_data_csv, test_data_csv = train_test_split(
        config.train_data_csv_path, config.n_test_experiments)

    train_image_ids, train_labels = get_data(train_data_csv, is_train=True)
    train_transform = TrainTransform(config.crop_size)
    train_dataset = CellsDataset(config.train_images_dir, train_image_ids,
                                 train_labels, train_transform)

    test_image_ids, test_labels = get_data(test_data_csv, is_train=True)
    test_dataset = CellsDataset(config.train_images_dir, test_image_ids,
                                test_labels)

    if torch.distributed.get_rank() == 0:
        print(
            f'Train size: {len(train_dataset)}, test_size: {len(test_dataset)}'
        )

    encoder = Encoder(config.n_image_channels, config.n_emedding_channels,
                      config.n_classes, config.encoder_model,
                      config.encoder_pretrained, config.encoder_dropout,
                      config.encoder_scale)

    if config.restore_checkpoint_path is not None:
        state_dict = torch.load(config.restore_checkpoint_path,
                                map_location='cpu')
        encoder.load_state_dict(state_dict, strict=False)

    decoder = Decoder(config.n_emedding_channels, config.n_image_channels,
                      config.n_classes, config.decoder_n_channels)

    trainer = Trainer(encoder=encoder,
                      decoder=decoder,
                      optimizer_params={
                          'lr': config.lr,
                          'weight_decay': config.weight_decay,
                          'warmap': config.warmap,
                          'amsgrad': config.amsgrad
                      },
                      amp_params={
                          'opt_level': config.opt_level,
                          'loss_scale': config.loss_scale
                      },
                      rank=args.local_rank,
                      n_jobs=config.n_jobs)
    trainer.train(train_data=train_dataset,
                  n_epochs=config.n_epochs,
                  batch_size=config.batch_size,
                  test_data=test_dataset,
                  best_checkpoint_path=config.best_checkpoint_path)
def train():
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    vocab_size = len(vocab)
    print('vocab_size:', vocab_size)

    dataloader = get_loader(image_dir,
                            caption_path,
                            vocab,
                            batch_size,
                            crop_size,
                            shuffle=True,
                            num_workers=num_workers)

    encoder = Encoder(embedding_size).to(device)
    decoder = Decoder(vocab_size, embedding_size, lstm_size).to(device)
    if os.path.exists(encoder_path):
        encoder.load_state_dict(torch.load(encoder_path))
    if os.path.exists(decoder_path):
        decoder.load_state_dict(torch.load(decoder_path))

    loss_fn = torch.nn.CrossEntropyLoss()
    parameters = list(encoder.fc.parameters()) + list(
        encoder.bn.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 lr=learning_rate,
                                 betas=(0.9, 0.99))

    num_steps = len(dataloader)
    for epoch in range(num_epochs):
        for index, (imgs, captions, lengths) in enumerate(dataloader):
            imgs = imgs.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(
                captions, lengths,
                batch_first=True)[0]  # the trailing [0] takes the packed .data tensor

            features = encoder(imgs)
            y_predicted = decoder(features, captions, lengths)
            loss = loss_fn(y_predicted, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if index % log_every == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, num_epochs, index, num_steps, loss.item(),
                            np.exp(loss.item())))

            if index % save_every == 0 and index != 0:
                print('Start saving encoder')
                torch.save(encoder.state_dict(), encoder_path)
                print('Start saving decoder')
                torch.save(decoder.state_dict(), decoder_path)
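
The trailing [0] after pack_padded_sequence takes the .data field of the returned PackedSequence: a flat tensor with the padding removed, which matches the flattened decoder outputs fed to CrossEntropyLoss. A tiny illustration:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

caps = torch.tensor([[1, 2, 3],
                     [4, 5, 0]])  # batch of 2, padded to length 3
packed = pack_padded_sequence(caps, lengths=[3, 2], batch_first=True)
print(packed[0])                  # tensor([1, 4, 2, 5, 3]) -- padding dropped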
Example #8
def convert(cfg):
    dataset_path = Path(utils.to_absolute_path("datasets")) / cfg.dataset.path
    with open(dataset_path / "speakers.json") as file:
        speakers = sorted(json.load(file))

    synthesis_list_path = Path(utils.to_absolute_path(cfg.synthesis_list))
    with open(synthesis_list_path) as file:
        synthesis_list = json.load(file)

    in_dir = Path(utils.to_absolute_path(cfg.in_dir))
    out_dir = Path(utils.to_absolute_path(cfg.out_dir))
    out_dir.mkdir(exist_ok=True, parents=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(**cfg.model.encoder)
    decoder = Decoder(**cfg.model.decoder)
    encoder.to(device)
    decoder.to(device)

    print("Load checkpoint from: {}:".format(cfg.checkpoint))
    checkpoint_path = utils.to_absolute_path(cfg.checkpoint)
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    encoder.load_state_dict(checkpoint["encoder"])
    decoder.load_state_dict(checkpoint["decoder"])

    encoder.eval()
    decoder.eval()

    for wav_path, speaker_id, out_filename in tqdm(synthesis_list):
        wav_path = in_dir / wav_path
        wav, _ = librosa.load(
            wav_path.with_suffix(".wav"),
            sr=cfg.preprocessing.sr)
        wav = wav / np.abs(wav).max() * 0.999

        mel = librosa.feature.melspectrogram(
            preemphasis(wav, cfg.preprocessing.preemph),
            sr=cfg.preprocessing.sr,
            n_fft=cfg.preprocessing.n_fft,
            n_mels=cfg.preprocessing.n_mels,
            hop_length=cfg.preprocessing.hop_length,
            win_length=cfg.preprocessing.win_length,
            fmin=cfg.preprocessing.fmin,
            power=1)
        logmel = librosa.amplitude_to_db(mel, top_db=cfg.preprocessing.top_db)
        logmel = logmel / cfg.preprocessing.top_db + 1

        mel = torch.FloatTensor(logmel).unsqueeze(0).to(device)
        speaker = torch.LongTensor([speakers.index(speaker_id)]).to(device)
        with torch.no_grad():
            z, _ = encoder.encode(mel)
            output = decoder.generate(z, speaker)

        path = out_dir / out_filename
        librosa.output.write_wav(path.with_suffix(".wav"), output.astype(np.float32), sr=cfg.preprocessing.sr)
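
The two lines that normalize logmel rely on amplitude_to_db with top_db clipping the spectrogram to a top_db-wide decibel range below its maximum; dividing by top_db and adding 1 then maps the values (roughly [-top_db, 0] dB here, since the waveform is peak-normalized) into about [0, 1]. A quick check of the arithmetic:

import numpy as np

top_db = 80.0
logmel = np.array([-80.0, -40.0, 0.0])  # dB values after amplitude_to_db
print(logmel / top_db + 1)              # [0.  0.5 1. ]
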
def encode_dataset(cfg):
    out_dir = Path(utils.to_absolute_path(cfg.out_dir))
    out_dir.mkdir(exist_ok=True, parents=True)

    root_path = Path(utils.to_absolute_path("datasets")) / cfg.dataset.path
    with open(root_path / "test.json") as file:
        metadata = json.load(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(**cfg.model.encoder)
    encoder.to(device)

    print("Load checkpoint from: {}:".format(cfg.checkpoint))
    checkpoint_path = utils.to_absolute_path(cfg.checkpoint)
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)
    encoder.load_state_dict(checkpoint["encoder"])

    encoder.eval()

    if cfg.save_auxiliary:
        auxiliary = []

        def hook(module, input, output):
            auxiliary.append(output.clone())

        encoder.encoder[-1].register_forward_hook(hook)

    for _, _, _, path in tqdm(metadata):
        path = root_path.parent / path
        mel = torch.from_numpy(np.load(
            path.with_suffix(".mel.npy"))).unsqueeze(0).to(device)
        with torch.no_grad():
            z, c, indices = encoder.encode(mel)

        z = z.squeeze().cpu().numpy()

        out_path = out_dir / path.stem
        with open(out_path.with_suffix(".txt"), "w") as file:
            np.savetxt(file, z, fmt="%.16f")

        if cfg.save_auxiliary:
            aux_path = out_dir.parent / "auxiliary_embedding1"
            aux_path.mkdir(exist_ok=True, parents=True)
            out_path = aux_path / path.stem
            c = c.squeeze().cpu().numpy()
            with open(out_path.with_suffix(".txt"), "w") as file:
                np.savetxt(file, c, fmt="%.16f")

            aux_path = out_dir.parent / "auxiliary_embedding2"
            aux_path.mkdir(exist_ok=True, parents=True)
            out_path = aux_path / path.stem
            aux = auxiliary.pop().squeeze().cpu().numpy()
            with open(out_path.with_suffix(".txt"), "w") as file:
                np.savetxt(file, aux, fmt="%.16f")
def main():
    args = check_argv()

    # Code indices
    code_indices_fn = Path(args.code_indices_fn)
    print("Reading: {}".format(code_indices_fn))
    code_indices = np.loadtxt(code_indices_fn, dtype=int)  # np.int was removed in NumPy 1.24

    # Speakers
    with open(Path("datasets/2019/english/speakers.json")) as f:
        speakers = sorted(json.load(f))

    # Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = Encoder(in_channels=80,
                      channels=768,
                      n_embeddings=512,
                      embedding_dim=64,
                      jitter=0.5)
    decoder = Decoder(
        in_channels=64,
        conditioning_channels=128,
        n_speakers=102,
        speaker_embedding_dim=64,
        mu_embedding_dim=256,
        rnn_channels=896,
        fc_channels=256,
        bits=8,
        hop_length=160,
    )
    decoder.to(device)

    print("Reading: {}".format(args.checkpoint))
    checkpoint_path = args.checkpoint
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)
    encoder.load_state_dict(checkpoint["encoder"])
    decoder.load_state_dict(checkpoint["decoder"])
    encoder.eval()
    decoder.eval()

    # Codes
    embedding = encoder.codebook.embedding.cpu().numpy()
    codes = np.array([embedding[code_indices]])

    # Synthesize
    z = torch.FloatTensor(codes).to(device)
    speaker = torch.LongTensor([speakers.index(args.speaker)]).to(device)
    with torch.no_grad():
        output = decoder.generate(z, speaker)

    wav_fn = Path(code_indices_fn.stem).with_suffix(".wav")
    print("Writing: {}".format(wav_fn))
    librosa.output.write_wav(wav_fn, output.astype(np.float32), sr=16000)
Example #11
def load_encoder(obs_space, args, freeze=True):
    enc = Encoder(obs_space, args.dim,
                  use_conv=args.use_conv)
    enc_state = torch.load(args.dynamics_module,
                           map_location=lambda storage, loc: storage)['enc']
    enc.load_state_dict(enc_state)
    enc.eval()
    if freeze:
        for p in enc.parameters():
            p.requires_grad = False
    return enc
def load_encoder(data_root, weight_path, device):
    encoder = Encoder()
    if weight_path:
        weight = torch.load(weight_path)
    else:
        weight = torch.load(get_best_weight(data_root))
    encoder.load_state_dict(weight)

    if device >= 0:
        encoder = encoder.to(f"cuda:{device}")
    encoder.eval()
    return encoder
Example #13
class AMDIMEncoder(nn.Module):
    def __init__(self, state_dict):
        super().__init__()

        config = {
            "ndf": 128,
            "num_channels": 12,
            "n_rkhs": 1024,
            "n_depth": 3,
            "encoder_size": 128,
            "use_bn": 0,
        }

        dummy_batch = torch.zeros(
            (2, config['num_channels'], config['encoder_size'],
             config['encoder_size']))

        self.encoder = Encoder(dummy_batch, **config)

        state_dict = {k: v for k, v in state_dict.items() if 'encoder.' in k}
        state_dict = {
            k.replace('encoder.', '').replace('module.', ''): v
            for k, v in state_dict.items()
        }
        self.encoder.load_state_dict(state_dict)

        self.transform = BENTransformValid()

    def forward(self, x):
        assert len(x.shape) == 4, "Input must be (batch_size, 12, 128, 128)"
        assert x.shape[1] == 12, "Input must be (batch_size, 12, 128, 128)"
        assert x.shape[2] == 128, "Input must be (batch_size, 12, 128, 128)"
        assert x.shape[3] == 128, "Input must be (batch_size, 12, 128, 128)"

        # --
        # Preprocessing

        device = x.device
        x = x.cpu()

        tmp = [xx.numpy().transpose(1, 2, 0) for xx in x]
        tmp = [self.transform(xx) for xx in tmp]
        x = torch.stack(tmp)

        x = x.to(device)

        # --
        # Forward
        acts = self.encoder._forward_acts(x)
        out = self.encoder.rkhs_block_1(acts[self.encoder.dim2layer[1]])
        out = out[:, :, 0, 0]

        return out
Example #14
class MyModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        custom_config = model_config["custom_options"]
        latent_size = custom_config['latent_size']

        self.main = Encoder(latent_size=latent_size)

        if custom_config['encoder_path'] is not None:
            # saved checkpoints could contain extra weights such as linear_logsigma
            weights = torch.load(custom_config['encoder_path'],
                                 map_location=torch.device('cpu'))
            for k in list(weights.keys()):
                if k not in self.main.state_dict().keys():
                    del weights[k]
            self.main.load_state_dict(weights)
            print("Loaded Weights")
        else:
            print("No Load Weights")

        self.critic = nn.Sequential(nn.Linear(latent_size, 400), nn.ReLU(),
                                    nn.Linear(400, 300), nn.ReLU(),
                                    nn.Linear(300, 1))
        self.actor = nn.Sequential(nn.Linear(latent_size, 400), nn.ReLU(),
                                   nn.Linear(400, 300), nn.ReLU())
        self.alpha_head = nn.Sequential(nn.Linear(300, 3), nn.Softplus())
        self.beta_head = nn.Sequential(nn.Linear(300, 3), nn.Softplus())
        self._cur_value = None
        self.train_encoder = custom_config['train_encoder']
        print("Train Encoder: ", self.train_encoder)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        features = self.main(input_dict['obs'].float())
        if not self.train_encoder:
            features = features.detach()  # do not backprop into the encoder

        actor_features = self.actor(features)
        alpha = self.alpha_head(actor_features) + 1
        beta = self.beta_head(actor_features) + 1
        logits = torch.cat([alpha, beta], dim=1)
        self._cur_value = self.critic(features).squeeze(1)

        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, 'Must call forward() first'
        return self._cur_value
Example #15
class MyModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        custom_config = model_config['custom_options']

        self.main = Encoder()

        if custom_config['encoder_path'] is not None:
            print("Load Trained Encoder")
            # saved checkpoints could contain extra weights such as linear_logsigma
            weights = torch.load(custom_config['encoder_path'],
                                 map_location={'cuda:0': 'cpu'})
            for k in list(weights.keys()):
                if k not in self.main.state_dict().keys():
                    del weights[k]
            self.main.load_state_dict(weights)

        self.critic = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(),
                                    nn.Linear(1024, 256), nn.ReLU(),
                                    nn.Linear(256, 1))
        self.actor = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(),
                                   nn.Linear(1024, 256), nn.ReLU(),
                                   nn.Linear(256, 3), nn.Sigmoid())
        self.actor_logstd = nn.Parameter(torch.zeros(3), requires_grad=True)
        self._cur_value = None
        print("Train Encoder:", custom_config['train_encoder'])
        self.train_encoder = custom_config['train_encoder']

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        features = self.main(input_dict['obs'].float())
        if not self.train_encoder:
            features = features.detach()  # do not backprop into the encoder

        actor_mu = self.actor(features)  # Bx3
        batch_size = actor_mu.shape[0]
        actor_logstd = torch.stack(batch_size * [self.actor_logstd],
                                   dim=0)  # Bx3
        logits = torch.cat([actor_mu, actor_logstd], dim=1)
        self._cur_value = self.critic(features).squeeze(1)

        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, 'Must call forward() first'
        return self._cur_value
def sample(img_path):
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    vocab_size = len(vocab)
    print('vocab_size:', vocab_size)

    dataloader = get_loader(image_dir,
                            caption_path,
                            vocab,
                            batch_size,
                            crop_size,
                            shuffle=True,
                            num_workers=num_workers)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    img = Image.open(img_path)
    imgs = transform(img).to(device).unsqueeze(0)

    encoder = Encoder(embedding_size).to(device).eval()
    decoder = Decoder(attention_dim, embedding_size, lstm_size,
                      vocab_size).to(device).eval()

    encoder.load_state_dict(torch.load(encoder_path))
    decoder.load_state_dict(torch.load(decoder_path))

    # Avoid accumulating gradients, which could cause out-of-memory errors
    with torch.no_grad():
        features = encoder(imgs)
        captions = decoder.generate(features, vocab.word2idx['<sos>'])

    captions = captions.cpu().data.numpy()

    def translate(indices):
        sentences = list()
        for index in indices:
            word = vocab.idx2word[int(index)]
            if word == '<eos>':
                break
            sentences.append(word)
        return ' '.join(sentences)

    sentences = translate(captions[0])
    print(sentences)
Example #17
def main(args):
    random.seed(0)

    with open(args.config_path, 'r') as file:
        config = AttrDict(json.load(file))

    data_csv = pd.read_csv(config.test_data_csv_path)
    image_ids, labels = get_data(data_csv, is_train=False)
    dataset = CellsDataset(config.test_images_dir, image_ids)
    dataloader = DataLoader(dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.n_jobs)

    encoder = Encoder(config.n_image_channels, config.n_emedding_channels,
                      config.n_classes, config.encoder_model,
                      config.encoder_pretrained, config.encoder_dropout,
                      config.encoder_scale)

    if config.restore_checkpoint_path is not None:
        state_dict = torch.load(config.restore_checkpoint_path,
                                map_location='cpu')
        encoder.load_state_dict(state_dict)

    device = torch.device('cuda:0')
    encoder = encoder.half().to(device)
    encoder.eval()

    transforms = [
        torch_none, torch_rot90, torch_rot180, torch_rot270, torch_random_crop,
        torch_random_crop, torch_random_crop
    ]

    predicted = [[] for _ in range(len(transforms))]
    for images in tqdm(dataloader):
        for i, t in enumerate(transforms):
            transformed_images = t(images)
            transformed_images = transformed_images.half().to(device)
            log_probs = encoder(transformed_images)
            predicted[i].append(log_probs.float().cpu().numpy())

    predicted = [
        np.concatenate(predicted[i], axis=0) for i in range(len(transforms))
    ]
    make_submit(predicted, labels)
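
The TTA helpers (torch_rot90 and friends) are defined elsewhere and not shown. A plausible torch_rot90 for an NCHW batch, given only as an assumption to make the test-time-augmentation loop concrete:

import torch

def torch_rot90(images: torch.Tensor) -> torch.Tensor:
    # Rotate every image in an NCHW batch by 90 degrees in the spatial plane.
    return torch.rot90(images, k=1, dims=(2, 3))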
Example #18
def predict(image_name, model_path=None):
    print(len(data.dictionary))
    encoder = Encoder()
    decoder = DecoderWithAttention(len(data.dictionary))
    if cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    if model_path:
        print('Loading the parameters of model.')
        if cuda:
            encoder.load_state_dict(torch.load(model_path[0]))
            decoder.load_state_dict(torch.load(model_path[1]))
        else:
            encoder.load_state_dict(
                torch.load(model_path[0], map_location='cpu'))
            decoder.load_state_dict(
                torch.load(model_path[1], map_location='cpu'))
    encoder.eval()
    decoder.eval()

    image = cv2.imread(image_name)
    image = cv2.resize(image, (224, 224))
    image = image.astype(np.float32) / 255.0
    image = image.transpose([2, 0, 1])
    image = np.expand_dims(image, axis=0)
    image = torch.from_numpy(image).type(torch.FloatTensor)
    if cuda:
        image = image.cuda()

    output = encoder(image)
    # print('encoder output:', output.size())
    sentences, alphas = beam_search(data, decoder, output)
    # print(sentences)
    show(image_name, sentences[0], alphas[0])

    for sentence in sentences:
        prediction = []
        for word in sentence:
            prediction.append(data.dictionary[word])
            if word == 2:
                break
        # print(prediction)
        prediction = ' '.join(prediction)
        print('The prediction sentence:', prediction)
def main(args):
    # Image preprocessing
    # In the generation phase we don't need random crop, just resize
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build model
    encoder = Encoder(embed_size=args.embed_size).eval()
    decoder = Decoder(stateful=False,
                      embed_size=args.embed_size,
                      hidden_size=args.hidden_size,
                      vocab_size=len(vocab),
                      num_layers=args.num_layers).to(device)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # load the trained model parameters
    encoder.load_state_dict(torch.load(args.encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(args.decoder_path, map_location=device))

    # Prepare an image
    image = load_image(args.image, transform)
    image_tensor = image.to(device)
    # Generate a caption from the image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    sampled_ids = sampled_ids[0].cpu().numpy()

    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<<end>>':
            break
    sentence = ' '.join(sampled_caption)
    print(sentence)
Example #20
def initialize_for_test(params):
    data_loader = get_loader(params, mode='test')
    encoder_file = os.path.join(params.encoder_save,
                                'epoch-%d.pkl' % params.num_epochs)
    decoder_file = os.path.join(params.decoder_save,
                                'epoch-%d.pkl' % params.num_epochs)
    vocab_size = len(data_loader.dataset.vocab)

    # Initialize the encoder and decoder, and set each to inference mode.
    encoder = Encoder(params)
    decoder = Decoder(params, vocab_size)
    encoder.eval()
    decoder.eval()

    # Load the trained weights.
    encoder.load_state_dict(torch.load(encoder_file))
    decoder.load_state_dict(torch.load(decoder_file))
    encoder.to(params.device)
    decoder.to(params.device)
    return data_loader, encoder, decoder
Example #21
def instantiate_model(config, tokenizer):
    configure_devices(config)
    model = Model(config)
    optimizer = transformers.AdamW(model.parameters(),
                                   lr=config.learning_rate,
                                   weight_decay=0)
    last_epoch = 0
    epoch_avg_loss = 0
    if config.continue_training:
        state_dict = torch.load(config.continue_training, map_location='cpu')
        model.load_state_dict(state_dict['model'])
        if 'optimizer_state_dict' in state_dict:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        last_epoch = state_dict['epoch']
        # epoch_avg_loss = state_dict['loss']
        # del state_dict # TODO TEST
    if config.use_cuda:
        model = model.cuda()
        optimizer_to(optimizer, config.device)
        model = torch.nn.DataParallel(model, device_ids=config.devices)
    return model, optimizer, last_epoch, epoch_avg_loss
Example #22
def infer(opt):
    cuda = True if torch.cuda.is_available() else False
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # Dimensionality
    shared_dim = opt.dim * (2**opt.n_downsample)

    # Initialize generator and discriminator
    shared_E = ResidualBlock(in_channels=shared_dim)
    shared_G = ResidualBlock(in_channels=shared_dim)

    E1 = Encoder(dim=opt.dim,
                 n_downsample=opt.n_downsample,
                 shared_block=shared_E)
    G2 = Generator(dim=opt.dim,
                   n_upsample=opt.n_upsample,
                   shared_block=shared_G)

    shared_E.load_state_dict(
        torch.load(opt.load_model.replace('*', 'shared_E')))
    shared_G.load_state_dict(
        torch.load(opt.load_model.replace('*', 'shared_G')))
    E1.load_state_dict(torch.load(opt.load_model.replace('*', 'E1')))
    G2.load_state_dict(torch.load(opt.load_model.replace('*', 'G2')))

    if cuda:
        shared_E.cuda()
        shared_G.cuda()
        E1 = E1.cuda()
        G2 = G2.cuda()

    sample = load_img(opt)
    sample = Variable(sample.unsqueeze(0).type(FloatTensor))
    _, Z1 = E1(sample)
    fake_X2 = G2(Z1)

    sample = torch.cat((sample.data, fake_X2.data), -1)
    save_image(sample, "images/infer.png", nrow=1, normalize=True)
Example #23
def load_model(
    encoder_path,
    decoder_path,
    vocab_size,
    layer_type='gru',
    embed_size=256,
    hidden_size=512,
    num_layers=2,
):
    if layer_type == 'lstm':
        from model import Encoder, Decoder
    else:
        from model_gru import Encoder, Decoder

    # eval mode (batchnorm uses moving mean/variance)
    encoder = Encoder(embed_size).eval()
    decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers)
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    # Load the trained model parameters
    encoder.load_state_dict(torch.load(encoder_path))
    decoder.load_state_dict(torch.load(decoder_path))
    return encoder, decoder
Example #24
def main(args):
    global batch_size
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    w_embed_size = args.w_embed_size
    lr = args.lr

    train_file = 'data/train_data_nv.txt'

    vocab = Vocab()
    vocab.build(train_file)

    if args.pre_trained_embed == 'n':
        encoder = Encoder(vocab.n_words, w_embed_size, hidden_size,
                          batch_size).to(device)
        decoder = AttentionDecoder(vocab.n_words, w_embed_size, hidden_size,
                                   batch_size).to(device)
    else:
        # load pre-trained embedding
        weight = vocab.load_weight(path="data/komoran_hd_2times.vec")
        encoder = Encoder(vocab.n_words, w_embed_size, hidden_size, batch_size,
                          weight).to(device)
        decoder = AttentionDecoder(vocab.n_words, w_embed_size, hidden_size,
                                   batch_size, weight).to(device)

    if args.encoder:
        encoder.load_state_dict(torch.load(args.encoder))
        print("[INFO] load encoder with %s" % args.encoder)
    if args.decoder:
        decoder.load_state_dict(torch.load(args.decoder))
        print("[INFO] load decoder with %s" % args.decoder)

    train_data = prep.read_train_data(train_file)
    train_loader = data.DataLoader(train_data,
                                   batch_size=batch_size,
                                   shuffle=True)

    # ev.evaluateRandomly(encoder, decoder, train_data, vocab, batch_size)
    # ev.evaluate_with_print(encoder, vocab, batch_size)

    # initialize
    max_a_at_5, max_a_at_1 = ev.evaluate_similarity(encoder,
                                                    vocab,
                                                    batch_size,
                                                    decoder=decoder)
    # max_a_at_5, max_a_at_1 = 0, 0
    max_bleu = 0

    total_epoch = args.epoch
    print(args)
    for epoch in range(1, total_epoch + 1):
        random.shuffle(train_data)
        trainIters(args,
                   epoch,
                   encoder,
                   decoder,
                   total_epoch,
                   train_data,
                   vocab,
                   train_loader,
                   print_every=2,
                   learning_rate=lr)

        if epoch % 20 == 0:
            a_at_5, a_at_1 = ev.evaluate_similarity(encoder,
                                                    vocab,
                                                    batch_size,
                                                    decoder=decoder)

            if a_at_1 > max_a_at_1:
                max_a_at_1 = a_at_1
                print("[INFO] New record! accuracy@1: %.4f" % a_at_1)

            if a_at_5 > max_a_at_5:
                max_a_at_5 = a_at_5
                print("[INFO] New record! accuracy@5: %.4f" % a_at_5)
                if args.save == 'y':
                    torch.save(encoder.state_dict(), 'encoder-max.model')
                    torch.save(decoder.state_dict(), 'decoder-max.model')
                    print("[INFO] new model saved")

            bleu = ev.evaluateRandomly(encoder, decoder, train_data, vocab,
                                       batch_size)
            if bleu > max_bleu:
                max_bleu = bleu
                if args.save == 'y':
                    torch.save(encoder.state_dict(), 'encoder-max-bleu.model')
                    torch.save(decoder.state_dict(), 'decoder-max-bleu.model')
                    print("[INFO] new model saved")

    print("Done! max accuracy@5: %.4f, max accuracy@1: %.4f" %
          (max_a_at_5, max_a_at_1))
    print("max bleu: %.2f" % max_bleu)
    if args.save == 'y':
        torch.save(encoder.state_dict(), 'encoder-last.model')
        torch.save(decoder.state_dict(), 'decoder-last.model')
def eval_reward(args, shared_model, writer_dir=None):
    """
	For evaluation

	Arguments:
	- writer: the tensorboard summary writer directory (note: can't get it working directly with the SummaryWriter object)
	"""
    writer = SummaryWriter(log_dir=os.path.join(
        writer_dir, 'eval')) if writer_dir is not None else None

    # current episode stats
    episode_reward = episode_value_mse = episode_td_error = episode_pg_loss = episode_length = 0

    # global stats
    i_episode = 0
    total_episode = total_steps = 0
    num_goals_achieved = 0

    # initialize the env and models
    torch.manual_seed(args.seed)
    env = create_env(args.env_name, framework=args.framework, args=args)
    set_seed(args.seed, env, args.framework)

    shared_enc, shared_dec, shared_d_module, shared_r_module = shared_model

    enc = Encoder(env.observation_space.shape[0],
                  args.dim,
                  use_conv=args.use_conv)
    dec = Decoder(env.observation_space.shape[0],
                  args.dim,
                  use_conv=args.use_conv)
    d_module = D_Module(env.action_space.shape[0], args.dim, args.discrete)
    r_module = R_Module(env.action_space.shape[0],
                        args.dim,
                        discrete=args.discrete,
                        baseline=False,
                        state_space=env.observation_space.shape[0])

    all_params = chain(enc.parameters(), dec.parameters(),
                       d_module.parameters(), r_module.parameters())

    if args.from_checkpoint is not None:
        model_state, _ = torch.load(args.from_checkpoint)
        # NOTE: `model` is not defined in this scope; presumably the checkpoint
        # state is meant to be loaded into the modules created above.
        model.load_state_dict(model_state)

    # set the model to evaluation mode
    enc.eval()
    dec.eval()
    d_module.eval()
    r_module.eval()

    # reset the state
    state = env.reset()
    state = Variable(torch.from_numpy(state).float())

    start = time.time()

    while total_episode < args.num_episodes:

        # Sync with the shared model
        r_module.load_state_dict(shared_r_module.state_dict())
        d_module.load_state_dict(shared_d_module.state_dict())
        enc.load_state_dict(shared_enc.state_dict())
        dec.load_state_dict(shared_dec.state_dict())

        # reset stuff
        cd_p = Variable(torch.zeros(1, args.lstm_dim))
        hd_p = Variable(torch.zeros(1, args.lstm_dim))

        # for the reward
        cr_p = Variable(torch.zeros(1, args.lstm_dim))
        hr_p = Variable(torch.zeros(1, args.lstm_dim))

        i_episode += 1
        episode_length = 0
        episode_reward = 0
        args.local = True
        args.d = 0
        succ, _, episode_reward, episode_length = test(1, args, args, args,
                                                       d_module, r_module, enc)
        log("Eval: succ {:.2f}, reward {:.2f}, length {:.2f}".format(
            succ, episode_reward, episode_length))
        # Episode has ended, write the summaries here
        if writer_dir is not None:
            # current episode stats
            writer.add_scalar('eval/episode_reward', episode_reward, i_episode)
            writer.add_scalar('eval/episode_length', episode_length, i_episode)
            writer.add_scalar('eval/success', succ, i_episode)

        time.sleep(args.eval_every)
        print("sleep")
Example #26
def main():

    parser = argparse.ArgumentParser(
        description='Estimate average error and std for each MNIST dataset')
    parser.add_argument('--model-name',
                        type=str,
                        required=True,
                        help='filepath of model to use')
    parser.add_argument('--output-name',
                        type=str,
                        required=True,
                        help='name of output files')
    parser.add_argument('--batch-size',
                        type=int,
                        default=200,
                        metavar='N',
                        help='batch-size for evaluation')

    args = parser.parse_args()

    #Load model
    path = '/home/ubuntu/Saved_Models/'
    filename = os.path.join(path, args.model_name, 'checkpoint.pt')

    use_cuda = torch.cuda.is_available()
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Encoder(device)
    model.load_state_dict(torch.load(filename))
    model = model.cuda()

    data_root_file = '/home/ubuntu/mnist-interpretable-tranformations/data'
    data_loaders = {
        digit: DataLoader(MNISTDadataset(data_root_file, digit),
                          batch_size=args.batch_size,
                          shuffle=False,
                          **kwargs)
        for digit in range(0, 10)
    }

    step = 5  #degrees step
    mean_error = pd.DataFrame()
    mean_abs_error = pd.DataFrame()
    error_std = pd.DataFrame()

    for digit, data_loader in data_loaders.items():
        sys.stdout.write('Processing digit {} \n'.format(digit))
        sys.stdout.flush()
        results = get_metrics(model, data_loader, device, step)
        mean_error[digit] = pd.Series(results[0])
        mean_abs_error[digit] = pd.Series(results[1])
        error_std[digit] = pd.Series(results[2])

    mean_error.index = mean_error.index * step
    mean_abs_error.index = mean_abs_error.index * step
    error_std.index = error_std.index * step

    mean_error.to_csv(args.output_name + '_mean_error.csv')
    mean_abs_error.to_csv(args.output_name + '_mean_abs_error.csv')
    error_std.to_csv(args.output_name + '_error_std.csv')

    ## Plotting just the absolute error
    with plt.style.context('ggplot'):
        mean_abs_error.plot(figsize=(9, 8))
        plt.xlabel('Degrees')
        plt.ylabel('Average error in degrees')
        plt.legend(loc="upper left",
                   bbox_to_anchor=[0, 1],
                   ncol=2,
                   shadow=True,
                   title="Digits",
                   fancybox=True)

        plt.tick_params(colors='gray', direction='out')
        plt.savefig(args.output_name + '_abs_mean_curves.png')
        plt.close()

    ## Plotting absolute error and std
    with plt.style.context('ggplot'):
        fig = plt.figure(figsize=(9, 8))
        ax = fig.add_subplot(111)
        x = mean_abs_error.index
        for digit in mean_abs_error.columns:
            mean = mean_abs_error[digit]
            std = error_std[digit]
            line, = ax.plot(x, mean)
            ax.fill_between(x,
                            mean - std,
                            mean + std,
                            alpha=0.2,
                            facecolor=line.get_color(),
                            edgecolor=line.get_color())

        ax.set_xlabel('Degrees')
        ax.set_ylabel('Average error in degrees')
        ax.legend(loc="upper left",
                  bbox_to_anchor=[0, 1],
                  ncol=2,
                  shadow=True,
                  title="Digits",
                  fancybox=True)
        ax.tick_params(colors='gray', direction='out')
        fig.savefig(args.output_name + '_mean_&_std_curves.png')
        fig.clf()
Example #27
class ALADTrainer:
    def __init__(self, args, data, device):
        self.args = args
        self.train_loader, _ = data
        self.device = device
        self.build_models()

    def train(self):
        """Training the ALAD"""

        if self.args.pretrained:
            self.load_weights()

        optimizer_ge = optim.Adam(list(self.G.parameters()) +
                                  list(self.E.parameters()),
                                  lr=self.args.lr,
                                  betas=(0.5, 0.999))
        params_ = list(self.Dxz.parameters()) \
                + list(self.Dzz.parameters()) \
                + list(self.Dxx.parameters())
        optimizer_d = optim.Adam(params_, lr=self.args.lr, betas=(0.5, 0.999))

        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs + 1):
            ge_losses = 0
            d_losses = 0
            for x, _ in Bar(self.train_loader):
                #Defining labels
                y_true = Variable(torch.ones((x.size(0), 1)).to(self.device))
                y_fake = Variable(torch.zeros((x.size(0), 1)).to(self.device))

                #Cleaning gradients.
                optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                #Generator:
                z_real = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_gen = self.G(z_real)

                #Encoder:
                x_real = x.float().to(self.device)
                z_gen = self.E(x_real)

                #Discriminatorxz
                out_truexz, _ = self.Dxz(x_real, z_gen)
                out_fakexz, _ = self.Dxz(x_gen, z_real)

                #Discriminatorzz
                out_truezz, _ = self.Dzz(z_real, z_real)
                out_fakezz, _ = self.Dzz(z_real, self.E(self.G(z_real)))

                #Discriminatorxx
                out_truexx, _ = self.Dxx(x_real, x_real)
                out_fakexx, _ = self.Dxx(x_real, self.G(self.E(x_real)))

                #Losses
                loss_dxz = criterion(out_truexz, y_true) + criterion(
                    out_fakexz, y_fake)
                loss_dzz = criterion(out_truezz, y_true) + criterion(
                    out_fakezz, y_fake)
                loss_dxx = criterion(out_truexx, y_true) + criterion(
                    out_fakexx, y_fake)
                loss_d = loss_dxz + loss_dzz + loss_dxx

                loss_gexz = criterion(out_fakexz, y_true) + criterion(
                    out_truexz, y_fake)
                loss_gezz = criterion(out_fakezz, y_true) + criterion(
                    out_truezz, y_fake)
                loss_gexx = criterion(out_fakexx, y_true) + criterion(
                    out_truexx, y_fake)
                cycle_consistency = loss_gezz + loss_gexx
                loss_ge = loss_gexz + loss_gezz + loss_gexx  # + cycle_consistency
                #Computing gradients and backpropagate.
                loss_d.backward(retain_graph=True)
                loss_ge.backward()
                optimizer_d.step()
                optimizer_ge.step()

                d_losses += loss_d.item()

                ge_losses += loss_ge.item()

            if epoch % 10 == 0:
                vutils.save_image((self.G(fixed_z).data + 1) / 2.,
                                  './images/{}_fake.png'.format(epoch))

            print(
                "Training... Epoch: {}, Discriminator Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, d_losses / len(self.train_loader),
                        ge_losses / len(self.train_loader)))
        self.save_weights()

    def build_models(self):
        self.G = Generator(self.args.latent_dim).to(self.device)
        self.E = Encoder(self.args.latent_dim,
                         self.args.spec_norm).to(self.device)
        self.Dxz = Discriminatorxz(self.args.latent_dim,
                                   self.args.spec_norm).to(self.device)
        self.Dxx = Discriminatorxx(self.args.spec_norm).to(self.device)
        self.Dzz = Discriminatorzz(self.args.latent_dim,
                                   self.args.spec_norm).to(self.device)
        self.G.apply(weights_init_normal)
        self.E.apply(weights_init_normal)
        self.Dxz.apply(weights_init_normal)
        self.Dxx.apply(weights_init_normal)
        self.Dzz.apply(weights_init_normal)

    def save_weights(self):
        """Save weights."""
        state_dict_Dxz = self.Dxz.state_dict()
        state_dict_Dxx = self.Dxx.state_dict()
        state_dict_Dzz = self.Dzz.state_dict()
        state_dict_E = self.E.state_dict()
        state_dict_G = self.G.state_dict()
        torch.save(
            {
                'Generator': state_dict_G,
                'Encoder': state_dict_E,
                'Discriminatorxz': state_dict_Dxz,
                'Discriminatorxx': state_dict_Dxx,
                'Discriminatorzz': state_dict_Dzz
            },
            'weights/model_parameters_{}.pth'.format(self.args.normal_class))

    def load_weights(self):
        """Load weights."""
        state_dict = torch.load('weights/model_parameters.pth')

        self.Dxz.load_state_dict(state_dict['Discriminatorxz'])
        self.Dxx.load_state_dict(state_dict['Discriminatorxx'])
        self.Dzz.load_state_dict(state_dict['Discriminatorzz'])
        self.G.load_state_dict(state_dict['Generator'])
        self.E.load_state_dict(state_dict['Encoder'])
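
Note the retain_graph=True on loss_d.backward() in train() above: the discriminator and generator losses share parts of one computation graph, so the intermediate buffers must survive the first backward pass for loss_ge.backward() to succeed. A minimal illustration of the mechanism:

import torch

x = torch.ones(2, requires_grad=True)
y = (3 * x).sum()
y.backward(retain_graph=True)  # keep intermediate buffers alive
y.backward()                   # a second backward through the same graph now works
print(x.grad)                  # tensor([6., 6.]) -- gradients accumulate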
Example #28
# Check input arguments validation
for path in glob.glob(args.img_glob):
    assert os.path.isfile(path), '%s not found' % path
for path in glob.glob(args.line_glob):
    assert os.path.isfile(path), '%s not found' % path
assert os.path.isdir(
    args.output_dir), '%s is not a directory' % args.output_dir
assert 0 <= args.alpha <= 1, '--alpha should be in [0, 1]'
for rotate in args.rotate:
    assert 0 <= rotate <= 1, 'elements in --rotate should be in [0, 1]'

# Prepare model
encoder = Encoder().to(device)
edg_decoder = Decoder(skip_num=2, out_planes=3).to(device)
cor_decoder = Decoder(skip_num=3, out_planes=1).to(device)
encoder.load_state_dict(torch.load('%s_encoder.pth' % args.path_prefix))
edg_decoder.load_state_dict(torch.load('%s_edg_decoder.pth' %
                                       args.path_prefix))
cor_decoder.load_state_dict(torch.load('%s_cor_decoder.pth' %
                                       args.path_prefix))

# Load path to visualization
img_paths = sorted(glob.glob(args.img_glob))
line_paths = sorted(glob.glob(args.line_glob))
assert len(img_paths) == len(
    line_paths), '# of input mismatch for each channels'


def augment(x_img):
    aug_type = ['']
    x_imgs_augmented = [x_img]
    if not args.no_ema:
        e_ema.eval()
        accumulate(e_ema, encoder, 0)

    if args.use_latent_teacher_forcing:
        # encoder that predicts w
        e_tf = Encoder(args.size,
                       args.latent,
                       channel_multiplier=args.channel_multiplier,
                       which_latent=args.which_latent,
                       stddev_group=args.stddev_group,
                       reparameterization=False).to(device)
        e_tf.eval()
        ckpt = torch.load(args.etf_ckpt,
                          map_location=lambda storage, loc: storage)
        e_tf.load_state_dict(ckpt["e_ema"])
    else:
        e_tf = None

    # For lazy regularization (see paper appendix page 11)
    e_reg_ratio = args.e_reg_every / (args.e_reg_every +
                                      1) if args.e_reg_every > 0 else 1.
    d_reg_ratio = args.d_reg_every / (args.d_reg_every +
                                      1) if args.d_reg_every > 0 else 1.

    e_optim = optim.Adam(
        encoder.parameters(),
        lr=args.lr * e_reg_ratio,
        betas=(0**e_reg_ratio, 0.99**e_reg_ratio),
    )
    d_optim = optim.Adam(
Example #30
def main(args):

    # ==============================
    # Create some folders or files for saving
    # ==============================

    if not os.path.exists(args.root_folder):
        os.mkdir(args.root_folder)

    loss_path = args.loss_path
    mertics_path = args.mertics_path
    epoch_model_path = args.epoch_model_path
    best_model_path = args.best_model_path
    generated_captions_path = args.generated_captions_folder_path
    sentences_show_path = args.sentences_show_path

    # Transform the format of images
    # This function in utils.general_tools.py
    train_transform = get_train_transform()
    val_transform = get_val_trainsform()

    # Load vocabulary
    print("*** Load Vocabulary ***")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Create data sets
    # This function in data_load.py
    train_data = train_load(root=args.train_image_dir,
                            json=args.train_caption_path,
                            vocab=vocab,
                            transform=train_transform,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    val_data = val_load(root=args.val_image_dir,
                        json=args.val_caption_path,
                        transform=val_transform,
                        batch_size=1,
                        shuffle=False,
                        num_workers=args.num_workers)

    # Build model
    encoder = Encoder(args.hidden_dim, args.fine_tuning).to(device)
    decoder = Decoder(args.embedding_dim, args.hidden_dim, vocab, len(vocab),
                      args.max_seq_length).to(device)

    # Select loss function
    criterion = nn.CrossEntropyLoss().to(device)

    if args.fine_tuning:
        params = list(decoder.parameters()) + list(encoder.parameters())
        optimizer = torch.optim.Adam(params, lr=args.fine_tuning_lr)
    else:
        params = decoder.parameters()
        optimizer = torch.optim.Adam(params, lr=args.fine_tuning_lr)

    # Load pretrained model
    if args.resume:
        checkpoint = torch.load(best_model_path)
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        if not args.fine_tuning:
            optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        best_score = checkpoint['best_score']
        best_epoch = checkpoint['best_epoch']

    # New epoch and score
    else:
        start_epoch = 1
        best_score = 0
        best_epoch = 0

    for epoch in range(start_epoch, 10000):

        print("-" * 20)
        print("epoch:{}".format(epoch))

        # Shrink the learning rate when the gap since the best epoch is a positive multiple of 4
        if (epoch - best_epoch) > 0 and (epoch - best_epoch) % 4 == 0:
            # This function in utils.general_tools.py
            adjust_lr(optimizer, args.shrink_factor)
        if (epoch - best_epoch) > 10:
            print("*** Training complete ***")
            break

        # =============
        # Training
        # =============

        print(" *** Training ***")
        decoder.train()
        encoder.train()
        total_step = len(train_data)
        epoch_loss = 0
        for (images, captions, lengths, img_ids) in tqdm(train_data):
            images = images.to(device)
            captions = captions.to(device)
            # Trim lengths by 1 and drop each caption's first token,
            # because we need to ignore the beginning symbol <start>
            lengths = list(np.array(lengths) - 1)

            targets = pack_padded_sequence(captions[:, 1:],
                                           lengths,
                                           batch_first=True)[0]
            features = encoder(images)
            predictions = decoder(features, captions, lengths)
            predictions = pack_padded_sequence(predictions,
                                               lengths,
                                               batch_first=True)[0]

            loss = criterion(predictions, targets)
            epoch_loss += loss.item()
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

        # Save loss information
        # This function in utils.save_tools.py
        save_loss(round(epoch_loss / total_step, 3), epoch, loss_path)

        # =============
        # Evaluating
        # =============

        print("*** Evaluating ***")
        encoder.eval()
        decoder.eval()
        generated_captions = []
        for image, img_id in tqdm(val_data):

            image = image.to(device)
            img_id = img_id[0]

            features = encoder(image)
            sentence = decoder.generate(features)
            sentence = ' '.join(sentence)
            item = {'image_id': int(img_id), 'caption': sentence}
            generated_captions.append(item)

        print('*** Computing metrics ***')

        # Save current generated captions
        # This function in utils.save_tools.py

        captions_json_path = save_generated_captions(generated_captions, epoch,
                                                     generated_captions_path,
                                                     args.fine_tuning)

        # Compute score of metrics
        # This function in utils.general_tools.py
        results = coco_metrics(args.val_caption_path, captions_json_path,
                               epoch, sentences_show_path)

        # Save metrics results
        # This function in utils.save_tools.py
        epoch_score = save_metrics(results, epoch, mertics_path)

        # Update the best score
        if best_score < epoch_score:

            best_score = epoch_score
            best_epoch = epoch

            save_best_model(encoder, decoder, optimizer, epoch, best_score,
                            best_epoch, best_model_path)

        print("*** Best score:{} Best epoch:{} ***".format(
            best_score, best_epoch))
        # Save every epoch model
        save_epoch_model(encoder, decoder, optimizer, epoch, best_score,
                         best_epoch, epoch_model_path, args.fine_tuning)