Пример #1
0
def main():
    """Entry point: parse CLI options, pretrain SimCLR on CIFAR-10 and save the encoder."""

    def _build_parser():
        # All CLI options for the SimCLR pretraining run.
        p = argparse.ArgumentParser(description='Implementation of SimCLR')
        p.add_argument('--EPOCHS', default=10, type=int,
                       help='Number of epochs for training')
        p.add_argument('--BATCH_SIZE', default=64, type=int, help='Batch size')
        p.add_argument('--TEMP', default=0.5, type=float,
                       help='Temperature parameter for NT-Xent')
        p.add_argument('--LOG_INT', default=100, type=int,
                       help='How many batches to wait before logging training status')
        p.add_argument('--DISTORT_STRENGTH', default=0.5, type=float,
                       help='Strength of colour distortion')
        p.add_argument('--SAVE_NAME', default='model')
        return p

    args = _build_parser().parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SimCLR augmentation pipeline followed by CIFAR-10 normalisation stats.
    augment = transforms.Compose([
        transforms.RandomResizedCrop((32, 32)),
        transforms.RandomHorizontalFlip(),
        get_color_distortion(s=args.DISTORT_STRENGTH),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010]),
    ])

    dataset = CIFAR10_new(root='./data',
                          train=True,
                          download=True,
                          transform=augment)

    # Drop the last (possibly smaller) minibatch to prevent matrix
    # multiplication errors in the loss computation.
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.BATCH_SIZE,
                                         shuffle=True,
                                         drop_last=True)

    model = Encoder().to(device)
    optimizer = optim.Adam(model.parameters())
    loss_func = losses.NTXentLoss(args.TEMP)
    for epoch in range(args.EPOCHS):
        train(args, model, device, loader, optimizer, loss_func, epoch)

    torch.save(model.state_dict(), './ckpt/{}.pth'.format(args.SAVE_NAME))
Пример #2
0
def pretrain(source_data_loader,
             test_data_loader,
             no_classes,
             embeddings,
             epochs=20,
             batch_size=128,
             cuda=False):
    """Jointly train the encoder and classifier on the source domain.

    Args:
        source_data_loader: iterable yielding (inputs, labels) batches.
        test_data_loader: loader used by eval_on_test for per-epoch accuracy.
        no_classes: per-class example counts, used to build loss weights.
        embeddings: pretrained embeddings handed to the Encoder.
        epochs: number of passes over the source loader.
        batch_size: unused here (batching is done by the loader); kept for
            interface compatibility.
        cuda: move models and batches to the GPU when True.

    Returns:
        (encoder, classifier) after training.
    """
    classifier = Classifier()
    encoder = Encoder(embeddings)

    if cuda:
        classifier.cuda()
        encoder.cuda()

    # Jointly optimize both encoder and classifier; only trainable encoder
    # parameters are handed to the optimizer.
    encoder_params = filter(lambda p: p.requires_grad, encoder.parameters())
    optimizer = optim.Adam(
        list(encoder_params) + list(classifier.parameters()))

    # Use class weights to normalize imbalance in the data:
    # weight_i = 1 / count_i, or 0 for classes that never occur.
    weights = torch.FloatTensor(len(no_classes))
    for i, count in enumerate(no_classes):
        weights[i] = 0 if count == 0 else 1 / count

    loss_fn = nn.CrossEntropyLoss(weight=Variable(weights))

    print('Training encoder and classifier')
    for e in range(epochs):

        # pretrain with whole source data -- use groups with DCD
        for sample in source_data_loader:
            x, y = Variable(sample[0]), Variable(sample[1])
            optimizer.zero_grad()

            if cuda:
                x, y = x.cuda(), y.cuda()

            output = model_fn(encoder, classifier)(x)

            loss = loss_fn(output, y)

            loss.backward()

            optimizer.step()

        # BUG FIX: `loss.data[0]` raises on PyTorch >= 0.5 (0-dim tensors
        # can no longer be indexed); `loss.item()` is the supported way to
        # read a scalar loss.
        print("Epoch", e, "Loss", loss.item(), "Accuracy",
              eval_on_test(test_data_loader, model_fn(encoder, classifier)))

    return encoder, classifier
Пример #3
0
def main(imgurl):
    """Caption the image at *imgurl* with the best saved encoder/decoder.

    Loads the word map and the checkpointed models, runs beam search to
    generate a caption, visualizes the attention maps, and returns the
    caption string (with the <start>/<end> tokens stripped).
    """
    # Load word map (word2ix)
    with open('input_files/WORDMAP.json', 'r') as j:
        word_map = json.load(j)
    rev_word_map = {v: k for k, v in word_map.items()}  # ix2word

    # Load model. This is inference only, so no optimizers are created
    # (the original built two Adam optimizers here that were never used).
    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=dropout)
    encoder = Encoder()
    encoder.fine_tune(fine_tune_encoder)

    decoder.load_state_dict(
        torch.load('output_files/BEST_checkpoint_decoder.pth.tar'))
    encoder.load_state_dict(
        torch.load('output_files/BEST_checkpoint_encoder.pth.tar'))

    decoder = decoder.to(device)
    decoder.eval()
    encoder = encoder.to(device)
    encoder.eval()

    # Encode, decode with attention and beam search
    seq, alphas = caption_image_beam_search(encoder,
                                            decoder,
                                            imgurl,
                                            word_map,
                                            beam_size=5)
    alphas = torch.FloatTensor(alphas)

    # Visualize caption and attention of best sequence
    # visualize_att(img, seq, alphas, rev_word_map, args.smooth)

    # Drop the first and last tokens (<start>/<end>) when joining words.
    words = [rev_word_map[ind] for ind in seq]
    caption = ' '.join(words[1:-1])
    visualize_att(imgurl, seq, alphas, rev_word_map)
    return caption
Пример #4
0
    def init_encoders(self):
        """
        Override to add your own encoders
        """
        hp = self.hparams

        # Shared constructor arguments for the query/key encoder pair.
        enc_kwargs = dict(input_dim=hp.input_dim,
                          hidden_dim=hp.hidden_dim,
                          bidirectional=hp.bidirectional,
                          embedding=hp.input_embedding,
                          cell=hp.cell,
                          num_layers=hp.num_layers)

        # Shared constructor arguments for the query/key decoder pair.
        dec_kwargs = dict(input_dim=hp.hidden_dim,
                          hidden_dim=hp.hidden_dim,
                          output_dim=hp.input_dim,
                          bidirectional=hp.bidirectional,
                          cell=hp.cell,
                          num_layers=hp.num_layers)

        encoder_q = Encoder(**enc_kwargs)
        encoder_k = Encoder(**enc_kwargs)
        decoder_q = Decoder(**dec_kwargs)
        decoder_k = Decoder(**dec_kwargs)

        # Xavier-initialize every 2-D weight matrix of encoder_q and
        # decoder_k. NOTE(review): the *query* encoder is paired with the
        # *key* decoder here — confirm this asymmetry is intentional.
        for weight in list(encoder_q.parameters()) + list(
                decoder_k.parameters()):
            if weight.dim() == 2:
                nn.init.xavier_uniform_(weight)

        return encoder_q, encoder_k, decoder_q, decoder_k
Пример #5
0
def pretrain(data, epochs=5, batch_size=128, cuda=False):
    """Pretrain an encoder + classifier on labelled source data.

    Args:
        data: tuple (X_s, y_s, _, _) of source inputs/labels (the last two
            elements are ignored here).
        epochs: number of training epochs.
        batch_size: size of each randomly sampled minibatch.
        cuda: move models and batches to the GPU when True.

    Returns:
        (encoder, classifier) after training.
    """
    X_s, y_s, _, _ = data

    test_dataloader = mnist_dataloader(train=False, cuda=cuda)

    classifier = Classifier()
    encoder = Encoder()

    if cuda:
        classifier.cuda()
        encoder.cuda()

    # Jointly optimize both encoder and classifier.
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()))
    loss_fn = nn.CrossEntropyLoss()

    for e in range(epochs):

        # len(X_s) // batch_size random minibatches per epoch, each drawn
        # by an independent permutation (so samples may repeat across
        # batches within one epoch).
        for _ in range(len(X_s) // batch_size):
            inds = torch.randperm(len(X_s))[:batch_size]

            x, y = Variable(X_s[inds]), Variable(y_s[inds])
            optimizer.zero_grad()

            if cuda:
                x, y = x.cuda(), y.cuda()

            y_pred = model_fn(encoder, classifier)(x)

            loss = loss_fn(y_pred, y)

            loss.backward()

            optimizer.step()

        # BUG FIX: `loss.data[0]` raises on PyTorch >= 0.5; use the
        # supported `loss.item()` to read the scalar value.
        print("Epoch", e, "Loss", loss.item(), "Accuracy",
              eval_on_test(test_dataloader, model_fn(encoder, classifier)))

    return encoder, classifier
def main():
    """
    Training and validation.

    Reads the word map, builds (or restores from `checkpoint`) the decoder
    with GloVe-initialised embeddings plus two Encoder instances, then
    alternates one training epoch with one validation pass, tracking
    BLEU-4 and checkpointing after every epoch.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map, rev_word_map

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Inverse mapping (index -> word), kept as a module-level global.
    rev_word_map = {v: k for k, v in word_map.items()}

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        # Initialise decoder embeddings from GloVe and keep fine-tuning them.
        pretrained_embs, pretrained_embs_dim = load_embeddings(
            '/home/Iwamura/datasets/datasets/GloVe/glove.6B.300d.txt',
            word_map)
        assert pretrained_embs_dim == decoder.embed_dim
        decoder.load_pretrained_embeddings(pretrained_embs)
        decoder.fine_tune_embeddings(True)

        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        # NOTE(review): two encoders are built here, but only `encoder_opt`
        # is moved to the device and passed to train()/validate() below;
        # `encoder` and `encoder_optimizer` appear unused — confirm this
        # is intentional.
        encoder = Encoder()
        encoder_opt = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_opt.fine_tune(fine_tune_encoder_opt)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None
        encoder_optimizer_opt = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder_opt.parameters()),
            lr=encoder_opt_lr) if fine_tune_encoder_opt else None

    else:

        # Resume all training state from an existing checkpoint.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder_opt = checkpoint['encoder_opt']
        encoder_optimizer_opt = checkpoint['encoder_optimizer_opt']

        # Start fine-tuning the encoder if requested but not yet enabled
        # when the checkpoint was written.
        # if fine_tune_encoder is True and encoder_optimizer is None and encoder_optimizer_opt is None
        if fine_tune_encoder_opt is True and encoder_optimizer_opt is None:
            encoder_opt.fine_tune(fine_tune_encoder_opt)

            encoder_optimizer_opt = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder_opt.parameters()),
                                                     lr=encoder_opt_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)

    encoder_opt = encoder_opt.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders

    # ImageNet normalisation statistics.
    normalize_opt = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize_opt])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize_opt])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Stop after 10 epochs without BLEU-4 improvement; decay the
        # learning rate(s) by 0.8 every 4 epochs. (The original comment
        # claimed 8/20, which did not match the code below.)
        if epochs_since_improvement == 10:
            break
        if epoch > 0 and epoch % 4 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)

            if fine_tune_encoder_opt:
                adjust_learning_rate(encoder_optimizer_opt, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder_opt=encoder_opt,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer_opt=encoder_optimizer_opt,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder_opt=encoder_opt,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement,
                        encoder_opt, decoder, encoder_optimizer_opt,
                        decoder_optimizer, recent_bleu4, is_best)
Пример #7
0
from torch.utils.tensorboard import SummaryWriter

from models import Encoder, Decoder, Discriminator
from preprocessing import get_loader, inv_standardize
from utils import (Collector, reconstruction_loss_func, norm22)
from config import (knobs, log_dir_local_time, log_dir_last_modified,
                    checkpoints_dir_local_time, checkpoints_dir_last_modified,
                    interpolations_dir)

loader = get_loader()

# Build the three models and place them on the configured device
# (`knobs` is the global config dict imported from config).
encoder = Encoder().to(knobs["device"])
decoder = Decoder().to(knobs["device"])
discriminator = Discriminator().to(knobs["device"])

# One Adam optimizer per sub-network, each with its own learning rate.
opt_encoder = torch.optim.Adam(encoder.parameters(), lr=knobs["lr_encoder"])
opt_decoder = torch.optim.Adam(decoder.parameters(), lr=knobs["lr_decoder"])
opt_discriminator = torch.optim.Adam(discriminator.parameters(),
                                     lr=knobs["lr_discriminator"])

# Metric collectors for the quantities logged during training
# (Collector is imported from utils; its exact semantics live there).
collector_reconstruction_loss = Collector()
collector_fooling_term = Collector()
collector_error_discriminator = Collector()
collector_heuristic_discriminator = Collector()
collector_codes_min = Collector()
collector_codes_max = Collector()
# Resume from the most recently modified checkpoint/log directories.
# NOTE(review): this `if` body appears to continue beyond this chunk;
# at minimum the epoch counter is restored here.
if knobs["resume"]:
    writer = SummaryWriter(log_dir_last_modified)
    checkpoint_dir = checkpoints_dir_last_modified
    checkpoint = torch.load(checkpoint_dir)
    starting_epoch = checkpoint["epoch"]
Пример #8
0
                      allow_pickle=True)

embedding_dim = 200  # word-embedding size
hidden_dim = 192  # RNN hidden-state size
BATCH_NUM = 10  # batch size handed to train2batch and the decoder
EPOCH_NUM = 30  # number of training epochs
vocab_size = len(id2word)  # id2word is loaded above (outside this chunk)
device = 'cpu'

# Sequence-to-sequence model: encoder plus attention decoder.
encoder = Encoder(vocab_size, embedding_dim, hidden_dim).to(device)
attn_decoder = AttentionDecoder(vocab_size, embedding_dim, hidden_dim,
                                BATCH_NUM).to(device)
# Loss function
criterion = nn.CrossEntropyLoss()
# Optimizers (one per sub-network)
encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
attn_decoder_optimizer = optim.Adam(attn_decoder.parameters(), lr=0.001)

# Loss history, filled during training.
all_losses = []


def main():
    for epoch in range(1, EPOCH_NUM + 1):
        epoch_loss = 0

        input_batch, output_batch = train2batch(input_data,
                                                output_data,
                                                batch_size=BATCH_NUM)
        for i in range(len(input_batch)):

            encoder_optimizer.zero_grad()
Пример #9
0
def main():
    """
    Training and validation.

    Builds (or restores from `checkpoint`) the encoder/decoder, optionally
    wraps them with Nvidia Apex amp for mixed-precision training, then
    alternates training and validation epochs, tracking BLEU-4 and
    checkpointing after each epoch.
    """

    global best_bleu4, use_amp, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    #use_amp = True
    #print("Using amp for mized precision training")

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        # The encoder only gets an optimizer when it is being fine-tuned.
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None

    else:
        # Resume all training state from the checkpoint.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # Start fine-tuning the encoder if requested but not enabled when
        # the checkpoint was written.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    # use mixed precision training using Nvidia Apex
    if use_amp:
        decoder, decoder_optimizer = amp.initialize(
                                    decoder, decoder_optimizer, opt_level="O2",
                                    keep_batchnorm_fp32=True, loss_scale="dynamic")
    encoder = encoder.to(device)
    if not encoder_optimizer:
        print("Encoder is not being optimized")
    elif use_amp:
        encoder, encoder_optimizer = amp.initialize(
                                    encoder, encoder_optimizer, opt_level="O2",
                                    keep_batchnorm_fp32=True, loss_scale="dynamic")

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders (ImageNet normalisation statistics)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)
Пример #10
0
def main(args):
    """
    Training and validation.

    Builds the encoder/decoder from `args`, optionally resumes their
    weights, then trains and validates every epoch, appending both losses
    to a log file and checkpointing every third epoch.
    """
    with open(args.word_map_file, 'rb') as f:
        word_map = pickle.load(f)


    # Choose which encoder to use (a transformer variant is kept
    # commented out below).
    encoder = Encoder(input_feature_dim=args.input_feature_dim,
                         encoder_hidden_dim=args.encoder_hidden_dim,
                         encoder_layer=args.encoder_layer,
                         rnn_unit=args.rnn_unit,
                         use_gpu=args.CUDA,
                         dropout_rate=args.dropout_rate
                         )

    #encoder = EncoderT(input_feature_dim=args.input_feature_dim,
     #                    encoder_hidden_dim=args.encoder_hidden_dim,
     #                    encoder_layer=args.encoder_layer,
     #                   rnn_unit=args.rnn_unit,
     #                    use_gpu=args.CUDA,
     #                    dropout_rate=args.dropout,
     #                    nhead=args.nhead
     #                    )

    decoder = DecoderWithAttention(attention_dim=args.attention_dim,
                                    embed_dim=args.emb_dim,
                                    decoder_dim=args.decoder_dim,
                                    vocab_size=len(word_map),
                                    dropout=args.dropout)

    if args.resume:
        encoder.load_state_dict(torch.load(args.encoder_path))
        decoder.load_state_dict(torch.load(args.decoder_path))


    encoder_parameter = [p for p in encoder.parameters() if p.requires_grad] # selecting every parameter.
    decoder_parameter = [p for p in decoder.parameters() if p.requires_grad]

    # NOTE(review): the encoder optimizer also uses args.decoder_lr —
    # confirm this is intentional and not a typo for args.encoder_lr.
    encoder_optimizer = torch.optim.Adam(encoder_parameter,lr=args.decoder_lr) #Adam selected
    decoder_optimizer = torch.optim.Adam(decoder_parameter,lr=args.decoder_lr)

    if args.CUDA:
        decoder = decoder.cuda()
        encoder = encoder.cuda()

    if args.CUDA:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = nn.CrossEntropyLoss() # just fall back to the CPU then

    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(args.data_path, split='TRAIN'),
        batch_size=args.batch_size, shuffle=True, num_workers=args.workers, collate_fn=pad_collate_train,pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(args.data_path,split='VAL'),
        batch_size=args.batch_size, shuffle=False, num_workers=args.workers,collate_fn=pad_collate_train, pin_memory=True)

    # Epochs
    best_bleu4 = 0
    for epoch in range(args.start_epoch, args.epochs):

        losst = train(train_loader=train_loader,  ## this is the training loss
              encoder = encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,args=args)

        # One epoch's validation (epoch % 1 == 0 is always true, so this
        # runs every epoch)
        if epoch%1==0:
            lossv = validate(val_loader=val_loader,
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion,
                            best_bleu=best_bleu4,
                            args=args)

        info = 'LOSST - {losst:.4f}, LOSSv - {lossv:.4f}\n'.format(
                losst=losst,
                lossv=lossv)

        # NOTE(review): `dev` is not defined in this function — it must be
        # a module-level path for the loss log file; confirm.
        with open(dev, "a") as f:    ## the validation loss is logged as well
            f.write(info)
            f.write("\n")

        # Selecting on BLEU would work as follows:
        #print('BLEU4: ' + bleu4)
        #print('best_bleu4 '+ best_bleu4)
        #if bleu4>best_bleu4:
        if epoch %3 ==0:
            save_checkpoint(epoch, encoder, decoder, encoder_optimizer,
                            decoder_optimizer, lossv)
Пример #11
0
      action = env.sample_random_action()
      next_observation, reward, done = env.step(action)
      # print(next_observation.shape, '   ############')
      D.append(observation, action, reward, done)
      observation = next_observation
      t += 1
    metrics['steps'].append(t * args.action_repeat + (0 if len(metrics['steps']) == 0 else metrics['steps'][-1]))
    metrics['episodes'].append(s)


# Initialise model parameters randomly
transition_model = TransitionModel(args.belief_size, args.state_size, env.action_size, args.hidden_size, args.embedding_size, args.activation_function).to(device=args.device)
observation_model = ObservationModel(args.symbolic_env, env.observation_size, args.belief_size, args.state_size, args.embedding_size, args.activation_function).to(device=args.device)
reward_model = RewardModel(args.belief_size, args.state_size, args.hidden_size, args.activation_function).to(device=args.device)
encoder = Encoder(args.symbolic_env, env.observation_size, args.embedding_size, args.activation_function).to(device=args.device)
# All four components are optimised jointly by a single Adam optimiser.
# With a learning-rate schedule configured, the lr starts at 0
# (presumably ramped up elsewhere — not visible in this chunk).
param_list = list(transition_model.parameters()) + list(observation_model.parameters()) + list(reward_model.parameters()) + list(encoder.parameters())
optimiser = optim.Adam(param_list, lr=0 if args.learning_rate_schedule != 0 else args.learning_rate, eps=args.adam_epsilon)
# BUG FIX: `args.models is not ''` compared identity with a string
# literal, which is a SyntaxWarning on Python 3.8+ and only worked via
# CPython string interning; use a value comparison instead.
if args.models != '' and os.path.exists(args.models):
  model_dicts = torch.load(args.models)
  transition_model.load_state_dict(model_dicts['transition_model'])
  observation_model.load_state_dict(model_dicts['observation_model'])
  reward_model.load_state_dict(model_dicts['reward_model'])
  encoder.load_state_dict(model_dicts['encoder'])
  optimiser.load_state_dict(model_dicts['optimiser'])
planner = MPCPlanner(env.action_size, args.planning_horizon, args.optimisation_iters, args.candidates, args.top_candidates, transition_model, reward_model)
global_prior = Normal(torch.zeros(args.batch_size, args.state_size, device=args.device), torch.ones(args.batch_size, args.state_size, device=args.device))  # Global prior N(0, I)
free_nats = torch.full((1, ), args.free_nats, device=args.device)  # Allowed deviation in KL divergence


def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, test):
  # Infer belief over current state q(s_t|o≤t,a<t) from the history
Пример #12
0
def train_model(cfg):
    """Train the encoder/decoder pair on speech data.

    Builds both models, one Adam optimizer over their combined parameters
    (wrapped with Apex amp at opt level "O1"), optionally resumes full
    training state from `cfg.resume`, then trains for roughly
    `cfg.training.n_steps` steps, logging per-epoch running averages to
    TensorBoard and checkpointing every `cfg.training.checkpoint_interval`
    global steps.
    """
    tensorboard_path = Path(utils.to_absolute_path("tensorboard")) / cfg.checkpoint_dir
    checkpoint_dir = Path(utils.to_absolute_path(cfg.checkpoint_dir))
    writer = SummaryWriter(tensorboard_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(**cfg.model.encoder)
    decoder = Decoder(**cfg.model.decoder)
    encoder.to(device)
    decoder.to(device)

    # One optimizer over both networks; amp handles mixed precision.
    optimizer = optim.Adam(
        chain(encoder.parameters(), decoder.parameters()),
        lr=cfg.training.optimizer.lr)
    [encoder, decoder], optimizer = amp.initialize([encoder, decoder], optimizer, opt_level="O1")
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=cfg.training.scheduler.milestones,
        gamma=cfg.training.scheduler.gamma)

    if cfg.resume:
        # Restore models, optimizer, amp, scheduler, and the step counter.
        print("Resume checkpoint from: {}:".format(cfg.resume))
        resume_path = utils.to_absolute_path(cfg.resume)
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        encoder.load_state_dict(checkpoint["encoder"])
        decoder.load_state_dict(checkpoint["decoder"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        amp.load_state_dict(checkpoint["amp"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    root_path = Path(utils.to_absolute_path("datasets")) / cfg.dataset.path
    dataset = SpeechDataset(
        root=root_path,
        hop_length=cfg.preprocessing.hop_length,
        sr=cfg.preprocessing.sr,
        sample_frames=cfg.training.sample_frames)

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.training.batch_size,
        shuffle=True,
        num_workers=cfg.training.n_workers,
        pin_memory=True,
        drop_last=True)

    # Convert the step budget into epochs; resume mid-run if needed.
    n_epochs = cfg.training.n_steps // len(dataloader) + 1
    start_epoch = global_step // len(dataloader) + 1

    for epoch in range(start_epoch, n_epochs + 1):
        # Running means over the epoch, updated incrementally below.
        average_recon_loss = average_vq_loss = average_perplexity = 0

        for i, (audio, mels, speakers) in enumerate(tqdm(dataloader), 1):
            audio, mels, speakers = audio.to(device), mels.to(device), speakers.to(device)

            optimizer.zero_grad()

            # The decoder conditions on audio[:, :-1] and is trained to
            # predict audio[:, 1:] (next-sample prediction).
            z, vq_loss, perplexity = encoder(mels)
            output = decoder(audio[:, :-1], z, speakers)
            recon_loss = F.cross_entropy(output.transpose(1, 2), audio[:, 1:])
            loss = recon_loss + vq_loss

            # Scale the loss for mixed-precision backprop, then clip the
            # master gradients before stepping.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
            optimizer.step()
            scheduler.step()

            # Incremental running average: avg += (x - avg) / i.
            average_recon_loss += (recon_loss.item() - average_recon_loss) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

            global_step += 1

            if global_step % cfg.training.checkpoint_interval == 0:
                save_checkpoint(
                    encoder, decoder, optimizer, amp,
                    scheduler, global_step, checkpoint_dir)

        writer.add_scalar("recon_loss/train", average_recon_loss, global_step)
        writer.add_scalar("vq_loss/train", average_vq_loss, global_step)
        writer.add_scalar("average_perplexity", average_perplexity, global_step)

        # BUG FIX: corrected the "perpexlity" typo in the log message.
        print("epoch:{}, recon loss:{:.2E}, vq loss:{:.2E}, perplexity:{:.3f}"
              .format(epoch, average_recon_loss, average_vq_loss, average_perplexity))
Пример #13
0
def main(checkpoint, tienet):
    """
    Training and validation entry point for the TieNet captioning model.

    Args:
        checkpoint: Path to an existing run directory to resume from, or a
            falsy value to start a fresh run in a timestamped directory.
        tienet: If truthy, also build/train the TieNet joint-learning
            classification head alongside the captioning decoder.
    """

    global best_bleu4, epochs_since_improvement, start_epoch, fine_tune_encoder, data_name, word_map

    # Resolve the run directory and the checkpoint file path.  Note the
    # parameter `checkpoint` is repurposed: directory in -> file path (or
    # None) out, and later rebound again to the loaded checkpoint dict.
    if checkpoint:
        dest_dir = checkpoint
        checkpoint = os.path.join(
            dest_dir,
            'checkpoint_mimiccxr_1_cap_per_img_5_min_word_freq.pth.tar'
        )  # path to checkpoint, None if none
    else:
        dest_dir = os.path.join(
            '/data/medg/misc/liuguanx/TieNet/models',
            datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S-%f'))
        os.makedirs(dest_dir)
        checkpoint = None
    # Read word map (token -> index) produced by dataset preprocessing.
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        # Fresh run: build decoder (+ optimizer over trainable params only).
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        # Encoder only needs an optimizer when it is being fine-tuned.
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None
        if (tienet):
            jointlearner = JointLearning(num_global_att=num_global_att,
                                         s=s,
                                         decoder_dim=decoder_dim,
                                         label_size=label_size)
            jointlearner_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, jointlearner.parameters()),
                                                      lr=jointlearning_lr)
        else:
            jointlearner = None
            jointlearner_optimizer = None

    else:
        # Resume: restore models, optimizers and bookkeeping from the
        # checkpoint dict (modules are stored whole, not as state_dicts).
        checkpoint = torch.load(checkpoint)
        print('checkpoint loaded')
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['best_bleu']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        jointlearner = checkpoint['jointlearner']
        jointlearner_optimizer = checkpoint['jointlearner_optimizer']
        # Fine-tuning may be switched on for a run that previously froze the
        # encoder; create its optimizer on the fly in that case.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    if torch.cuda.device_count() > 1:
        print('Using', torch.cuda.device_count(), 'GPUs')
        # decoder = nn.DataParallel(decoder)
        # NOTE(review): DataParallel with device_ids=[1] pins everything to a
        # single GPU despite the multi-GPU branch — confirm this is intended.
        encoder = nn.DataParallel(encoder, device_ids=[1])
        if tienet:
            jointlearner = nn.DataParallel(jointlearner, device_ids=[1])
    decoder = decoder.to(device)
    encoder = encoder.to(device)
    if tienet:
        jointlearner = jointlearner.to(device)

    # Loss functions: cross-entropy for caption generation (R = report),
    # BCE-with-logits for the multi-label classification head (C = class).
    criterion_R = nn.CrossEntropyLoss().to(device)
    criterion_C = nn.BCEWithLogitsLoss().to(device)

    # Custom dataloaders (ImageNet normalization stats for the pretrained CNN)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if tienet:
                adjust_learning_rate(jointlearner_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              jointlearner=jointlearner,
              criterion_R=criterion_R,
              criterion_C=criterion_C,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              jointlearner_optimizer=jointlearner_optimizer,
              epoch=epoch,
              dest_dir=dest_dir,
              tienet=tienet)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                jointlearner=jointlearner,
                                criterion_R=criterion_R,
                                criterion_C=criterion_C,
                                tienet=tienet)

        # Check if there was an improvement (BLEU-4 on the validation split)
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, jointlearner, encoder_optimizer,
                        decoder_optimizer, jointlearner_optimizer,
                        recent_bleu4, best_bleu4, is_best, dest_dir)
src_train_loader = get_dataloader(data_dir, src_dir, batch_size, train=True)
tgt_train_loader = get_dataloader(data_dir,
                                  tgt_train_dir,
                                  batch_size,
                                  train=True)
criterion = nn.CrossEntropyLoss()
# criterion_kl = nn.KLDivLoss()
if cuda:
    criterion = criterion.cuda()
    cl_classifier = cl_classifier.cuda()
    dm_classifier = dm_classifier.cuda()
    encoder = encoder.cuda()
soft_labels = gen_soft_labels(31, src_train_loader, encoder, cl_classifier)
# optimizer
optimizer = optim.SGD(list(encoder.parameters()) +
                      list(cl_classifier.parameters()),
                      lr=lr,
                      momentum=momentum)

optimizer_conf = optim.SGD(encoder.parameters(), lr=lr, momentum=momentum)

optimizer_dm = optim.SGD(dm_classifier.parameters(), lr=lr, momentum=momentum)
# begin training
encoder.train()
cl_classifier.train()
dm_classifier.train()
for epoch in range(1, epochs + 1):
    correct = 0
    for batch_idx, ((src_data, src_label_cl),
                    (tgt_data, tgt_label_cl)) in enumerate(
Пример #15
0
class Trainer:
    '''
    Trains a pair of cross-aligned VAEs (image features <-> class attributes)
    for (generalized) zero-shot learning, then fits a final classifier on
    latent codes synthesized from unseen-class attributes.
    '''

    def __init__(self, device, dset, x_dim, c_dim, z_dim, n_train, n_test, lr,
                 layer_sizes, **kwargs):
        '''
        Trainer class
        Args:
            device (torch.device) : Use GPU or CPU
            dset (str)            : Dataset name (used in save-file names)
            x_dim (int)           : Feature dimension
            c_dim (int)           : Attribute dimension
            z_dim (int)           : Latent dimension
            n_train (int)         : Number of training classes
            n_test (int)          : Number of testing classes
            lr (float)            : Learning rate for VAE
            layer_sizes (dict)    : Hidden layer sizes for the four VAE nets
                                    (keys: 'x_enc', 'x_dec', 'c_enc', 'c_dec')
            **kwargs              : Flags for using various regularizations
                                    ('gzsl', 'use_da', 'use_ca', 'use_support')
        '''
        self.device = device
        self.dset = dset
        self.lr = lr
        self.z_dim = z_dim

        self.n_train = n_train
        self.n_test = n_test
        # In generalized ZSL the final classifier must discriminate between
        # seen AND unseen classes, so widen its output space.
        self.gzsl = kwargs.get('gzsl', False)
        if self.gzsl:
            self.n_test = n_train + n_test

        # flags for various regularizers
        self.use_da = kwargs.get('use_da', False)
        self.use_ca = kwargs.get('use_ca', False)
        self.use_support = kwargs.get('use_support', False)

        # Image-feature VAE and class-attribute VAE share one latent space.
        self.x_encoder = Encoder(x_dim, layer_sizes['x_enc'],
                                 z_dim).to(self.device)
        self.x_decoder = Decoder(z_dim, layer_sizes['x_dec'],
                                 x_dim).to(self.device)

        self.c_encoder = Encoder(c_dim, layer_sizes['c_enc'],
                                 z_dim).to(self.device)
        self.c_decoder = Decoder(z_dim, layer_sizes['c_dec'],
                                 c_dim).to(self.device)

        # Auxiliary classifier on latent codes (only trained when
        # use_support is set, but always constructed).
        self.support_classifier = Classifier(z_dim,
                                             self.n_train).to(self.device)

        params = list(self.x_encoder.parameters()) + \
                 list(self.x_decoder.parameters()) + \
                 list(self.c_encoder.parameters()) + \
                 list(self.c_decoder.parameters())

        if self.use_support:
            params += list(self.support_classifier.parameters())

        # Single optimizer over all VAE (and optionally support) parameters.
        self.optimizer = optim.Adam(params, lr=lr)

        # Final classifier trained separately on synthetic latent data.
        self.final_classifier = Classifier(z_dim, self.n_test).to(self.device)
        self.final_cls_optim = optim.RMSprop(
            self.final_classifier.parameters(), lr=2e-4)
        self.criterion = nn.CrossEntropyLoss()

        self.vae_save_path = './saved_models'
        self.disc_save_path = './saved_models/disc_model_%s.pth' % self.dset

    def fit_VAE(self, x, c, y, ep):
        '''
        Train on 1 minibatch of data
        Args:
            x (torch.Tensor) : Features of size (batch_size, 2048)
            c (torch.Tensor) : Attributes of size (batch_size, attr_dim)
            y (torch.Tensor) : Target labels of size (batch_size,)
            ep (int)         : Epoch number
        Returns:
            Loss for the minibatch -
            3-tuple with (vae_loss, distributn loss, cross_recon loss)
        '''
        # Loss weights (beta/gamma/delta/alpha) are epoch-scheduled.
        self.anneal_parameters(ep)

        x = Variable(x.float()).to(self.device)
        c = Variable(c.float()).to(self.device)
        y = Variable(y.long()).to(self.device)

        # VAE for image embeddings
        mu_x, logvar_x = self.x_encoder(x)
        z_x = self.reparameterize(mu_x, logvar_x)
        x_recon = self.x_decoder(z_x)

        # VAE for class embeddings
        mu_c, logvar_c = self.c_encoder(c)
        z_c = self.reparameterize(mu_c, logvar_c)
        c_recon = self.c_decoder(z_c)

        # reconstruction loss
        L_recon_x = self.compute_recon_loss(x, x_recon)
        L_recon_c = self.compute_recon_loss(c, c_recon)

        # KL divergence loss
        D_kl_x = self.compute_kl_div(mu_x, logvar_x)
        D_kl_c = self.compute_kl_div(mu_c, logvar_c)

        # `compute_kl_div` returns the *negative* KL divergence, so
        # subtracting it here adds the KL penalty to the reconstruction loss.
        L_vae_x = L_recon_x - self.beta * D_kl_x
        L_vae_c = L_recon_c - self.beta * D_kl_c
        L_vae = L_vae_x + L_vae_c

        # Cross-alignment loss: decode each modality's latent with the
        # OTHER modality's decoder and penalize the reconstruction error.
        L_ca = torch.zeros(1).to(self.device)
        if self.use_ca:
            x_recon_from_c = self.x_decoder(z_c)
            L_ca_x = self.compute_recon_loss(x, x_recon_from_c)

            c_recon_from_x = self.c_decoder(z_x)
            L_ca_c = self.compute_recon_loss(c, c_recon_from_x)

            L_ca = L_ca_x + L_ca_c

        # Distribution alignment loss (Wasserstein distance between the two
        # per-sample latent Gaussians).
        L_da = torch.zeros(1).to(self.device)
        if self.use_da:
            L_da = 2 * self.compute_da_loss(mu_x, logvar_x, mu_c, logvar_c)

        # Support-classifier loss: NLL of the true label under the softmax
        # of the classifier applied to the image latent code.
        L_sup = torch.zeros(1).to(self.device)
        if self.use_support:
            # NOTE(review): softmax over dim=0 normalizes across the batch,
            # not across classes — this looks like it should be dim=1.
            # Preserved as-is; confirm against the reference implementation
            # before changing.
            y_prob = F.softmax(self.support_classifier(z_x), dim=0)
            log_prob = torch.log(torch.gather(y_prob, 1, y.unsqueeze(1)))
            L_sup = -1 * torch.mean(log_prob)

        total_loss = L_vae + self.gamma * L_ca + self.delta * L_da + self.alpha * L_sup

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return L_vae.item(), L_da.item(), L_ca.item()

    def reparameterize(self, mu, log_var):
        '''
        Reparameterization trick using unimodal gaussian
        '''
        # One noise scalar per sample, broadcast across the latent dims
        # (rather than independent noise per dimension).
        eps = Variable(torch.randn(mu.size()[0],
                                   1).expand(mu.size())).to(self.device)
        z = mu + torch.exp(log_var / 2.0) * eps
        return z

    def anneal_parameters(self, epoch):
        '''
        Change weight factors of various losses based on epoch number
        '''
        # weight of kl divergence loss (linear warm-up until epoch 90)
        if epoch <= 90:
            self.beta = 0.0026 * epoch

        # weight of Cross Alignment loss (warm-up between epochs 20 and 75)
        if epoch < 20:
            self.gamma = 0
        if epoch >= 20 and epoch <= 75:
            self.gamma = 0.044 * (epoch - 20)

        # weight of distribution alignment loss (warm-up epochs 5 to 22)
        if epoch < 5:
            self.delta = 0
        if epoch >= 5 and epoch <= 22:
            self.delta = 0.54 * (epoch - 5)

        # weight of support loss (constant after epoch 5)
        if epoch < 5:
            self.alpha = 0
        else:
            self.alpha = 0.01

    def compute_recon_loss(self, x, x_recon):
        '''
        Compute the reconstruction error (summed L1 distance).
        '''
        l1_loss = torch.abs(x - x_recon).sum()
        return l1_loss

    def compute_kl_div(self, mu, log_var):
        '''
        Compute the (negative) KL Divergence between N(mu, var) & N(0, 1).

        Note: this is 0.5*sum(1 + log_var - mu^2 - var), i.e. the NEGATIVE
        of the usual KL(N(mu, var) || N(0, 1)); callers subtract it.
        '''
        kld = 0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
        return kld

    def compute_da_loss(self, mu1, log_var1, mu2, log_var2):
        '''
        Computes Distribution Alignment loss between 2 normal distributions.
        Uses Wasserstein distance as distance measure.
        '''
        # Squared distance between means, per sample.
        l1 = (mu1 - mu2).pow(2).sum(dim=1)

        # Squared distance between standard deviations, per sample.
        std1 = (log_var1 / 2.0).exp()
        std2 = (log_var2 / 2.0).exp()
        l2 = (std1 - std2).pow(2).sum(dim=1)

        l_da = torch.sqrt(l1 + l2).sum()
        return l_da

    def fit_final_classifier(self, x, y):
        '''
        Train the final classifier on synthetically generated data.
        Returns the minibatch cross-entropy loss as a float.
        '''
        x = Variable(x.float()).to(self.device)
        y = Variable(y.long()).to(self.device)

        logits = self.final_classifier(x)
        loss = self.criterion(logits, y)

        self.final_cls_optim.zero_grad()
        loss.backward()
        self.final_cls_optim.step()

        return loss.item()

    def fit_MOE(self, x, y):
        '''
        Trains the synthetic dataset on a MoE model
        '''

    def get_vae_savename(self):
        '''
        Returns a string indicative of various flags used during training and
        dataset used. Works as a unique name for saving models
        '''
        flags = ''
        if self.use_da:
            flags += '-da'
        if self.use_ca:
            flags += '-ca'
        if self.use_support:
            flags += '-support'
        model_name = 'vae_model__dset-%s__lr-%f__z-%d__%s.pth' % (
            self.dset, self.lr, self.z_dim, flags)
        return model_name

    def save_VAE(self, ep):
        '''
        Persist VAE weights + optimizer state (keyed by flag/dataset name).
        '''
        state = {
            'epoch': ep,
            'x_encoder': self.x_encoder.state_dict(),
            'x_decoder': self.x_decoder.state_dict(),
            'c_encoder': self.c_encoder.state_dict(),
            'c_decoder': self.c_decoder.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        model_name = self.get_vae_savename()
        torch.save(state, os.path.join(self.vae_save_path, model_name))

    def load_models(self, model_path=''):
        '''
        Restore VAE weights and optimizer state from `model_path` (defaults
        to the canonical save name). Returns the stored epoch number, or 0
        if no checkpoint file exists.
        '''
        # BUG FIX: the original used `model_path is ''` — an identity
        # comparison against a str literal (SyntaxWarning on Python 3.8+)
        # that only works by CPython string interning. A truthiness test is
        # the correct, portable check.
        if not model_path:
            model_path = os.path.join(self.vae_save_path,
                                      self.get_vae_savename())

        ep = 0
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path)
            self.x_encoder.load_state_dict(checkpoint['x_encoder'])
            self.x_decoder.load_state_dict(checkpoint['x_decoder'])
            self.c_encoder.load_state_dict(checkpoint['c_encoder'])
            self.c_decoder.load_state_dict(checkpoint['c_decoder'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            ep = checkpoint['epoch']

        return ep

    def create_syn_dataset(self,
                           test_labels,
                           attributes,
                           seen_dataset,
                           n_samples=400):
        '''
        Creates a synthetic dataset based on attribute vectors of unseen class
        Args:
            test_labels: A dict with key as original serial number in provided
                dataset and value as the index which is predicted during
                classification by network
            attributes: A np array containing class attributes for each class
                of dataset
            seen_dataset: A list of 3-tuple (x, _, y) where x belongs to one of the
                seen classes and y is corresponding label. Used for generating
                latent representations of seen classes in GZSL
            n_samples: Number of samples of each unseen class to be generated(Default: 400)
        Returns:
            A list of 3-tuple (z, _, y) where z is latent representations and y is
            corresponding label
        '''
        syn_dataset = []
        for test_cls, idx in test_labels.items():
            # Class serial numbers are 1-based; attribute matrix is 0-based.
            attr = attributes[test_cls - 1]

            self.c_encoder.eval()
            c = Variable(torch.FloatTensor(attr).unsqueeze(0)).to(self.device)
            mu, log_var = self.c_encoder(c)

            # Draw n_samples latent codes for this unseen class.
            Z = torch.cat(
                [self.reparameterize(mu, log_var) for _ in range(n_samples)])

            syn_dataset.extend([(Z[i], test_cls, idx)
                                for i in range(n_samples)])

        # For GZSL, also embed real seen-class samples into the latent space.
        if seen_dataset is not None:
            self.x_encoder.eval()
            for (x, att_idx, y) in seen_dataset:
                x = Variable(torch.FloatTensor(x).unsqueeze(0)).to(self.device)
                mu, log_var = self.x_encoder(x)
                z = self.reparameterize(mu, log_var).squeeze()
                syn_dataset.append((z, att_idx, y))

        return syn_dataset

    def compute_accuracy(self, generator):
        '''
        Evaluate the final classifier over `generator`.
        Returns (seen_acc, unseen_acc) in GZSL mode, else a single accuracy.
        '''
        y_real_list, y_pred_list = [], []

        for idx, (x, _, y) in enumerate(generator):
            x = Variable(x.float()).to(self.device)
            y = Variable(y.long()).to(self.device)

            self.final_classifier.eval()
            self.x_encoder.eval()
            # Classify from the latent MEAN (no sampling at eval time).
            mu, log_var = self.x_encoder(x)
            logits = self.final_classifier(mu)

            _, y_pred = logits.max(dim=1)

            y_real = y.detach().cpu().numpy()
            y_pred = y_pred.detach().cpu().numpy()

            y_real_list.extend(y_real)
            y_pred_list.extend(y_pred)

        if self.gzsl:
            # Split accuracy by seen (label < n_train) vs unseen classes.
            y_real_list = np.asarray(y_real_list)
            y_pred_list = np.asarray(y_pred_list)

            y_seen_real = np.extract(y_real_list < self.n_train, y_real_list)
            y_seen_pred = np.extract(y_real_list < self.n_train, y_pred_list)

            y_unseen_real = np.extract(y_real_list >= self.n_train,
                                       y_real_list)
            y_unseen_pred = np.extract(y_real_list >= self.n_train,
                                       y_pred_list)

            acc_seen = accuracy_score(y_seen_real, y_seen_pred)
            acc_unseen = accuracy_score(y_unseen_real, y_unseen_pred)

            return acc_seen, acc_unseen

        else:
            return accuracy_score(y_real_list, y_pred_list)
Пример #16
0
def main():
    """
    Training and validation entry point for the attention-based captioning
    model on Flickr8k. Reads configuration from module-level globals and
    resumes from `checkpoint` when one is set.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map
    #word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    #with open(word_map_file, 'r') as j:
    #    word_map = json.load(j)

    # Vocabulary comes from a pickled Vocab object instead of the JSON word
    # map used by the original codebase (see commented block above).
    with open("/content/image_captioning/Image-Captioning-Codebase/vocab.pkl",
              "rb") as f:
        vocab = pickle.load(f)

    word_map = vocab.word2idx

    # Initialize / load checkpoint
    if checkpoint is None:
        # Fresh run: build decoder and (optionally fine-tuned) encoder.
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        # Encoder only needs an optimizer when it is being fine-tuned.
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        # Resume: `checkpoint` global is rebound from path to loaded dict.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # Fine-tuning may be enabled for a run that previously froze the
        # encoder; create its optimizer on the fly in that case.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders (ImageNet normalization stats for pretrained CNN)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # NOTE(review): `normalize` above is unused — transform_train defines
    # its own Normalize with the same stats.
    transform_train = transforms.Compose(
        [  # smaller edge of image resized to 256
            transforms.Resize(
                (224, 224)),  # get 224x224 crop from random location
            transforms.RandomHorizontalFlip(
            ),  # horizontally flip image with probability=0.5
            transforms.ToTensor(),  # convert the PIL Image to a tensor
            transforms.Normalize(
                (0.485, 0.456, 0.406),  # normalize image for pre-trained model
                (0.229, 0.224, 0.225))
        ])
    """
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    """
    # Flickr8k replaces the original CaptionDataset loaders (kept above for
    # reference). Validation uses the same train-time augmentation.
    train_loader = torch.utils.data.DataLoader(
        Flickr8kDataset(annot_path="/content/", img_path="/content/Flicker8k_Dataset/", \
                            split="train", transform=transform_train), \
                            batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        Flickr8kDataset(annot_path="/content/", img_path="/content/Flicker8k_Dataset/", \
                            split="dev", transform=transform_train), \
                            batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement (BLEU-4 on the validation split)
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
# Earlier GAT / SpGAT model variants kept for reference:
#     model = SpGAT(nfeat=3,
#                 nhid=args.hidden,
#                 nclass=12,
#                 dropout=args.dropout,
#                 nheads=args.nb_heads,
#                 alpha=args.alpha)
# else:
#     model = GAT(nfeat=3,
#                 nhid=args.hidden,
#                 nclass=12,
#                 dropout=args.dropout,
#                 nheads=args.nb_heads,
#                 alpha=args.alpha)
# Current model: Encoder with attention-head/dropout hyperparameters taken
# from the command-line args; 12 output classes, 7x7 spatial output.
model = Encoder(output_size=(7, 7), spatial_scale=1.0, hidden=args.hidden, nclass=12,
                dropout=args.dropout, nb_heads=args.nb_heads,  alpha=args.alpha)
# Plain SGD (no momentum/weight decay configured here).
optimizer = optim.SGD(model.parameters(),
                       lr=args.lr)

if args.cuda:
    model.cuda()
    # features = features.cuda()
    # adj = adj.cuda()
    # labels = labels.cuda()
    # idx_train = idx_train.cuda()
    # idx_val = idx_val.cuda()
    # idx_test = idx_test.cuda()

# features, adj, labels = Variable(features), Variable(adj), Variable(labels)
def train(epoch, train_loader, val_loader, logger=None):
Пример #18
0
    if args.label_type == 'director':
        criterion = nn.CrossEntropyLoss()
        num_classes = len(directors)
    elif args.label_type == 'genre':
        criterion = nn.BCEWithLogitsLoss()
        num_classes = len(genres)

    vgg = None
    if args.model_type == 'vgg-pretrained':
        vgg = VGG16(requires_grad=False).to(args.device)
        model = Encoder(num_classes=num_classes,
                        style_dim=style_dims[args.gram_ix])
    else:
        model = BasicNet(num_classes=num_classes, in_channels=3)
    model.to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    old_epoch = 0

    if args.load_model or args.eval:
        checkpoint = torch.load(PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        old_epoch = checkpoint['epoch']

    # Load data
    print('Loading data...')
    dataloader_train = load_movie_data(split='train',
                                       label_type=args.label_type,
                                       batch_size=args.batch_size)
    dataloader_val = load_movie_data(split='val',
Пример #19
0
def main():
    """
    Training and validation for an attention-based image captioning model.

    Builds (or restores from a checkpoint) an Encoder/DecoderWithAttention
    pair, then runs the train/validate loop with BLEU-4 driven early
    stopping (20 stagnant epochs) and LR decay (every 8 stagnant epochs).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_folder",
        default='data/',
        type=str,
        help="folder with data files saved by create_input_files.py")
    parser.add_argument("--data_name",
                        default='coco_5_cap_per_img_5_min_word_freq',
                        type=str,
                        help="base name shared by data files")
    parser.add_argument("--output_dir",
                        default='saved_models/',
                        type=str,
                        help="path to save checkpoints")
    parser.add_argument("--checkpoint",
                        default=None,
                        type=str,
                        help="path to checkpoint")
    parser.add_argument("--emb_dim",
                        default=512,
                        type=int,
                        help="dimension of word embeddings")
    parser.add_argument("--attention_dim",
                        default=512,
                        type=int,
                        help="dimension of attention linear layers")
    parser.add_argument("--decoder_dim",
                        default=512,
                        type=int,
                        help="dimension of decoder RNN")
    # BUGFIX: help text previously read "dimension of word embeddings"
    # (copy-paste from --emb_dim); this flag is the dropout probability.
    parser.add_argument("--dropout",
                        default=0.5,
                        type=float,
                        help="dropout probability in the decoder")
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument(
        "--epochs",
        default=120,
        type=int,
        help=
        "number of epochs to train for (if early stopping is not triggered)")
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for training and testing")
    parser.add_argument("--workers",
                        default=8,
                        type=int,
                        help="num of workers for data-loading")
    parser.add_argument("--encoder_lr", default=1e-4, type=float)
    parser.add_argument("--decoder_lr", default=5e-4, type=float)
    parser.add_argument("--grad_clip",
                        default=5,
                        type=float,
                        help="clip gradients at an absolute value of")
    # type=float (was int): the attention regularization weight is a
    # continuous hyper-parameter; float still accepts the old value "1".
    parser.add_argument(
        "--alpha_c",
        default=1.0,
        type=float,
        help=
        "regularization parameter for 'doubly stochastic attention', as in the paper"
    )
    parser.add_argument(
        "--print_freq",
        default=100,
        type=int,
        help="print training/validation stats every __ batches")
    parser.add_argument("--fine_tune_encoder",
                        action='store_true',
                        help="Whether to finetune the encoder")
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device

    best_bleu4 = 0
    epochs_since_improvement = 0  # epochs since the last BLEU-4 improvement

    # Read word map (token -> index); index 0 is treated as padding below.
    word_map_file = os.path.join(args.data_folder,
                                 'WORDMAP_' + args.data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize from scratch, or restore everything from a checkpoint.
    if args.checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=args.attention_dim,
                                       embed_dim=args.emb_dim,
                                       decoder_dim=args.decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=args.dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=args.decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(args.fine_tune_encoder)
        # The encoder only gets an optimizer when it is being fine-tuned.
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=args.encoder_lr) if args.fine_tune_encoder else None

    else:
        checkpoint = torch.load(args.checkpoint)
        args.start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # If fine-tuning was switched on after the checkpoint was written,
        # unfreeze the encoder and create its optimizer now.
        if args.fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(args.fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=args.encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(args.device)
    encoder = encoder.to(args.device)

    # Loss function; index 0 (padding) does not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0).to(args.device)

    # Custom dataloaders (ImageNet normalization statistics).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        args.data_folder,
        args.data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    print(f'train dataset length {len(train_loader)}')
    # NOTE(review): shuffle=True on validation mirrors the original; it does
    # not affect the aggregate BLEU-4 score.
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        args.data_folder,
        args.data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print(f'val dataset length {len(val_loader)}')

    # Epochs
    for epoch in range(args.start_epoch, args.epochs):

        # Stop after 20 stagnant epochs; decay LR every 8 stagnant epochs.
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if args.fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              args=args)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion,
                                word_map=word_map,
                                args=args)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(args.data_name, args.output_dir, epoch,
                        epochs_since_improvement, encoder, decoder,
                        encoder_optimizer, decoder_optimizer, recent_bleu4,
                        is_best)
def main(config, needs_save):
    """Train a convolutional VAE (encoder E + decoder D) with pytorch-ignite.

    Args:
        config: nested config object whose sections (`dataset`, `model`,
            `training`, `optimizer`, `save`) support `._asdict()`.
        needs_save: when True, persist config, TSV logs, sample images and
            model checkpoints under the run's output directory.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = config.training.visible_devices
    seed = check_manual_seed(config.training.seed)
    print('Using manual seed: {}'.format(seed))

    # Select the patient split by name; anything else is rejected outright.
    if config.dataset.patient_ids == 'TRAIN_PATIENT_IDS':
        patient_ids = TRAIN_PATIENT_IDS
    elif config.dataset.patient_ids == 'TEST_PATIENT_IDS':
        patient_ids = TEST_PATIENT_IDS
    else:
        raise NotImplementedError

    data_loader = get_data_loader(
        mode=config.dataset.mode,
        dataset_name=config.dataset.name,
        patient_ids=patient_ids,
        root_dir_path=config.dataset.root_dir_path,
        use_augmentation=config.dataset.use_augmentation,
        batch_size=config.dataset.batch_size,
        num_workers=config.dataset.num_workers,
        image_size=config.dataset.image_size)

    # Encoder/decoder pair; .float() forces float32 parameters.
    E = Encoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.enc_filters,
                activation=config.model.enc_activation).float()

    D = Decoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.dec_filters,
                activation=config.model.dec_activation,
                final_activation=config.model.dec_final_activation).float()

    # Optional spectral normalization on either network.
    if config.model.enc_spectral_norm:
        apply_spectral_norm(E)

    if config.model.dec_spectral_norm:
        apply_spectral_norm(D)

    if config.training.use_cuda:
        E.cuda()
        D.cuda()
        E = nn.DataParallel(E)
        D = nn.DataParallel(D)

    # Optionally warm-start from saved weights. NOTE(review): loading happens
    # after the DataParallel wrap, so saved state dicts are presumably keyed
    # with the "module." prefix -- confirm against how they were saved.
    if config.model.saved_E:
        print(config.model.saved_E)
        E.load_state_dict(torch.load(config.model.saved_E))

    if config.model.saved_D:
        print(config.model.saved_D)
        D.load_state_dict(torch.load(config.model.saved_D))

    print(E)
    print(D)

    # Separate Adam optimizers per network; betas=(0.9, 0.9999).
    e_optim = optim.Adam(filter(lambda p: p.requires_grad, E.parameters()),
                         config.optimizer.enc_lr, [0.9, 0.9999])

    d_optim = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()),
                         config.optimizer.dec_lr, [0.9, 0.9999])

    # alpha is read from config alongside beta/margin but is not used in the
    # VAE losses below -- NOTE(review): possibly reserved for another variant.
    alpha = config.training.alpha
    beta = config.training.beta
    margin = config.training.margin

    batch_size = config.dataset.batch_size
    # Fixed latent noise reused every epoch to visualize decoder progress.
    fixed_z = torch.randn(calc_latent_dim(config))

    if 'ssim' in config.training.loss:
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

    def l_recon(recon: torch.Tensor, target: torch.Tensor):
        # Reconstruction loss selected by config; SSIM terms are scaled by
        # the element count so they are comparable with the summed L1/L2.
        if config.training.loss == 'l2':
            loss = F.mse_loss(recon, target, reduction='sum')

        elif config.training.loss == 'l1':
            loss = F.l1_loss(recon, target, reduction='sum')

        elif config.training.loss == 'ssim':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon)

        elif config.training.loss == 'ssim+l1':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                 + F.l1_loss(recon, target, reduction='sum')

        elif config.training.loss == 'ssim+l2':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                 + F.mse_loss(recon, target, reduction='sum')

        else:
            raise NotImplementedError

        # beta weights reconstruction against the KL term (beta-VAE style).
        return beta * loss / batch_size

    def l_reg(mu: torch.Tensor, log_var: torch.Tensor):
        # KL divergence of N(mu, exp(log_var)) from N(0, I), per batch.
        loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
        return loss / batch_size

    def update(engine, batch):
        # One optimization step: encode, decode, backprop KL + reconstruction.
        E.train()
        D.train()

        image = norm(batch['image'])

        if config.training.use_cuda:
            image = image.cuda(non_blocking=True).float()
        else:
            image = image.float()

        e_optim.zero_grad()
        d_optim.zero_grad()

        z, z_mu, z_logvar = E(image)
        x_r = D(z)

        l_vae_reg = l_reg(z_mu, z_logvar)
        l_vae_recon = l_recon(x_r, image)
        l_vae_total = l_vae_reg + l_vae_recon

        l_vae_total.backward()

        e_optim.step()
        d_optim.step()

        if config.training.use_cuda:
            torch.cuda.synchronize()

        # Scalars consumed by the RunningAverage metrics attached below.
        return {
            'TotalLoss': l_vae_total.item(),
            'EncodeLoss': l_vae_reg.item(),
            'ReconLoss': l_vae_recon.item(),
        }

    output_dir = get_output_dir_path(config)
    trainer = Engine(update)
    timer = Timer(average=True)

    monitoring_metrics = ['TotalLoss', 'EncodeLoss', 'ReconLoss']

    # Exponential running average of each scalar returned by update();
    # partial(..., metric=metric) binds the loop variable early.
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(
                                                    trainer, metric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def save_config(engine):
        # Flatten the two-level config into plain dicts for JSON dumping.
        config_to_save = defaultdict(dict)

        for key, child in config._asdict().items():
            for k, v in child._asdict().items():
                config_to_save[key][k] = v

        config_to_save['seed'] = seed
        config_to_save['output_dir'] = output_dir

        print('Training starts by the following configuration: ',
              config_to_save)

        if needs_save:
            save_path = os.path.join(output_dir, 'config.json')
            with open(save_path, 'w') as f:
                json.dump(config_to_save, f)

    @trainer.on(Events.ITERATION_COMPLETED)
    def show_logs(engine):
        # Log the running metrics every log_iter_interval iterations.
        if (engine.state.iteration - 1) % config.save.log_iter_interval == 0:
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                   + [str(value) for value in engine.state.metrics.values()]

            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=config.training.n_epochs,
                i=engine.state.iteration,
                max_i=len(data_loader))

            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_logs(engine):
        # Append one TSV row per epoch; write the header on first use.
        if needs_save:
            fname = os.path.join(output_dir, 'logs.tsv')
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                   + [str(value) for value in engine.state.metrics.values()]

            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        # Periodically dump [input | reconstruction | sample-from-fixed-z]
        # grids for visual inspection.
        if needs_save:
            if engine.state.epoch % config.save.save_epoch_interval == 0:
                image = norm(engine.state.batch['image'])

                with torch.no_grad():
                    z, _, _ = E(image)
                    x_r = D(z)
                    x_p = D(fixed_z)

                image = denorm(image).detach().cpu()
                x_r = denorm(x_r).detach().cpu()
                x_p = denorm(x_p).detach().cpu()

                image = image[:config.save.n_save_images, ...]
                x_r = x_r[:config.save.n_save_images, ...]
                x_p = x_p[:config.save.n_save_images, ...]

                save_path = os.path.join(
                    output_dir, 'result_{}.png'.format(engine.state.epoch))
                save_image(torch.cat([image, x_r, x_p]).data, save_path)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.save.n_saved,
            create_dir=True,
        )
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'E': E,
                                      'D': D
                                  })

    # Average per-iteration wall time, reset each epoch by print_times().
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.training.n_epochs, config.training.n_epochs * len(data_loader)))

    trainer.run(data_loader, config.training.n_epochs)
# Example #21 (snippet separator from the scraped corpus)
def main():
    """Train and validate the attention-based image captioning model.

    Configuration comes from module-level globals (data_folder, data_name,
    emb_dim, decoder_lr, fine_tune_encoder, ...) rather than CLI arguments;
    the bookkeeping globals declared below are mutated across the
    train/validate loop.
    """
    '''
    The basic rules for global keyword in Python are:

    When we create a variable inside a function, it is local by default.
    When we define a variable outside of a function, it is global by default. You don't have to use global keyword.
    We use global keyword to read and write a global variable inside a function.
    Use of global keyword outside a function has no effect.
    
    '''

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map (token -> index) from the preprocessed JSON file.
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize from scratch or restore everything from a checkpoint.
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        '''
        The filter() method constructs an iterator from elements of an iterable for which a function returns true.

        The filter() method takes two parameters:

        function - function that tests if elements of an iterable returns true or false
                    If None, the function defaults to Identity function - which returns false if any elements are false

        iterable - iterable which is to be filtered, could be sets, lists, tuples, or containers of any iterators

        The filter() method returns an iterator that passed the function check for each element in the iterable.

        '''

        # Optimize only parameters that require gradients.
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        # The encoder only gets an optimizer when it is being fine-tuned.
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        # NOTE(review): `checkpoint` is rebound from a path string to the
        # loaded dict here (it is a global).
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # If fine-tuning was enabled after the checkpoint was written,
        # unfreeze the encoder and create its optimizer now.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders (ImageNet normalization statistics).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # If there's no improvement in Bleu score for 20 epochs then stop training
        if epochs_since_improvement == 20:
            break

        # If there's no improvement in Bleu score for 8 epochs lower the lr
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
print(datetime.now(), "val dataset done")
#sys.stdout.flush()

# Wrap the datasets built above (out of view) in shuffling loaders.
dataloader_tng = DataLoader(dataset_tng,
                            batch_size=BS,
                            shuffle=True,
                            num_workers=8)
print(datetime.now(), "tng dataloader done", len(dataloader_tng))
dataloader_val = DataLoader(dataset_val,
                            batch_size=BS,
                            shuffle=True,
                            num_workers=8)
print(datetime.now(), "val dataloader done", len(dataloader_val))

encoder = Encoder(dropout=dropout, n_img=n_img).to(device)
encoder_parameters = [p for p in encoder.parameters()]
# NOTE(review): the [1:] slice excludes the encoder's first parameter tensor
# from optimization -- presumably a frozen layer; confirm against Encoder.
optimizer = optim.Adam(encoder_parameters[1:], lr=LR, weight_decay=L2)
#scheduler = optim.lr_scheduler.LambdaLR(optimizer, step_decay)
# Decay LR by 10x after 10 epochs without improvement of the stepped metric.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                 factor=0.1,
                                                 patience=10)
loss_function = MRILoss(device, k_cosine=k_cosine)

# Per-epoch bookkeeping, filled in by the (out-of-view) training loop.
losses_tng = []
losses_val = []
examples_pred = []
examples_gt = []

#print(datetime.now(), "begin training")
#sys.stdout.flush()
# Example #23 (snippet separator from the scraped corpus)
print("Initializing model parameters!")
# Initialise model parameters randomly.
# World-model components (transition/observation/reward models plus an
# observation encoder), all trained jointly with a single Adam instance.
transition_model = TransitionModel(
    args.belief_size, args.state_size, env.action_size, args.hidden_size,
    args.embedding_size, args.activation_function).to(device=args.device)
observation_model = ObservationModel(
    args.symbolic_env, env.observation_size, args.belief_size, args.state_size,
    args.embedding_size, args.activation_function).to(device=args.device)
reward_model = RewardModel(args.belief_size, args.state_size, args.hidden_size,
                           args.activation_function).to(device=args.device)
encoder = Encoder(args.symbolic_env, env.observation_size, args.embedding_size,
                  args.activation_function).to(device=args.device)
# One optimizer over the concatenated parameters of all four modules.
param_list = list(transition_model.parameters()) + list(
    observation_model.parameters()) + list(reward_model.parameters()) + list(
        encoder.parameters())
optimiser = optim.Adam(param_list, lr=args.learning_rate, eps=1e-4)
if args.load_checkpoint > 0:
    # Resume all module and optimizer states from models_<N>.pth.
    model_dicts = torch.load(
        os.path.join(results_dir, 'models_%d.pth' % args.load_checkpoint))
    transition_model.load_state_dict(model_dicts['transition_model'])
    observation_model.load_state_dict(model_dicts['observation_model'])
    reward_model.load_state_dict(model_dicts['reward_model'])
    encoder.load_state_dict(model_dicts['encoder'])
    optimiser.load_state_dict(model_dicts['optimiser'])

# Detect discrete action spaces so downstream code can branch on `mode`;
# num_actions stays -1 for continuous spaces.
mode = "continuous"
num_actions = -1
if type(env._env.action_space) == gym.spaces.discrete.Discrete:
    mode = "discrete"
    num_actions = env._env.action_space.n
# Example #24 (snippet separator from the scraped corpus)
class Trainer():
    """Train, evaluate and visualise a PlaNet-style world model.

    Owns four jointly-optimised components (transition, observation and
    reward models plus an encoder), an MPC planner, and an experience
    replay buffer.  Provides model fitting from the buffer, exploratory
    data collection, periodic testing with video logging, and plan
    dumping for inspection.
    """

    def __init__(self, params, experience_replay_buffer, metrics, results_dir, env):
        """Build models, optimiser and planner; create output folders.

        params: hyper-parameter namespace (sizes, lr, device, ...).
        experience_replay_buffer: buffer exposing sample()/append()/store_dataset().
        metrics: dict of metric lists, updated in place across episodes.
        results_dir: root directory for statistics/models/videos.
        env: environment wrapper exposing action_size, action_range, step(), ...
        """
        self.parms = params
        self.D = experience_replay_buffer
        self.metrics = metrics
        self.env = env
        self.tested_episodes = 0

        # Output locations for statistics, checkpoints and videos.
        self.statistics_path = results_dir + '/statistics'
        self.model_path = results_dir + '/model'
        self.video_path = results_dir + '/video'
        self.rew_vs_pred_rew_path = results_dir + '/rew_vs_pred_rew'
        self.dump_plan_path = results_dir + '/dump_plan'

        # If folders do not exist, create them.
        os.makedirs(self.statistics_path, exist_ok=True)
        os.makedirs(self.model_path, exist_ok=True)
        os.makedirs(self.video_path, exist_ok=True)
        os.makedirs(self.rew_vs_pred_rew_path, exist_ok=True)
        os.makedirs(self.dump_plan_path, exist_ok=True)

        # Create models
        self.transition_model = TransitionModel(self.parms.belief_size, self.parms.state_size, self.env.action_size, self.parms.hidden_size, self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device)
        self.observation_model = ObservationModel(self.parms.belief_size, self.parms.state_size, self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device)
        self.reward_model = RewardModel(self.parms.belief_size, self.parms.state_size, self.parms.hidden_size, self.parms.activation_function).to(device=self.parms.device)
        self.encoder = Encoder(self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device)
        # One shared optimiser over all components; lr starts at 0 when an
        # external learning-rate schedule is configured (it is expected to ramp it up).
        self.param_list = list(self.transition_model.parameters()) + list(self.observation_model.parameters()) + list(self.reward_model.parameters()) + list(self.encoder.parameters())
        self.optimiser = optim.Adam(self.param_list, lr=0 if self.parms.learning_rate_schedule != 0 else self.parms.learning_rate, eps=self.parms.adam_epsilon)
        self.planner = MPCPlanner(self.env.action_size, self.parms.planning_horizon, self.parms.optimisation_iters, self.parms.candidates, self.parms.top_candidates, self.transition_model, self.reward_model, self.env.action_range[0], self.env.action_range[1])

        # NOTE(review): an unused local `global_prior` N(0, I) was removed here —
        # it was never referenced after construction.
        # Allowed deviation in KL divergence (free nats).
        self.free_nats = torch.full((1, ), self.parms.free_nats, dtype=torch.float32, device=self.parms.device)

    def load_checkpoints(self):
        """Restore metrics and, if a best-model checkpoint exists, all model
        and optimiser state dicts from it."""
        self.metrics = torch.load(self.model_path + '/metrics.pth')
        model_path = self.model_path + '/best_model'
        os.makedirs(model_path, exist_ok=True)
        files = os.listdir(model_path)
        if files:
            checkpoint = [f for f in files if os.path.isfile(os.path.join(model_path, f))]
            model_dicts = torch.load(os.path.join(model_path, checkpoint[0]), map_location=self.parms.device)
            self.transition_model.load_state_dict(model_dicts['transition_model'])
            self.observation_model.load_state_dict(model_dicts['observation_model'])
            self.reward_model.load_state_dict(model_dicts['reward_model'])
            self.encoder.load_state_dict(model_dicts['encoder'])
            self.optimiser.load_state_dict(model_dicts['optimiser'])
            print("Loading models checkpoints!")
        else:
            print("Checkpoints not found!")

    def update_belief_and_act(self, env, belief, posterior_state, action, observation, reward, min_action=-inf, max_action=inf, explore=False):
        """Advance the belief with the latest observation, plan the next action,
        optionally add exploration noise, and step the environment.

        Returns (belief, posterior_state, action, next_observation, reward,
        done, predicted_next_reward).
        """
        # Infer belief over current state q(s_t|o≤t,a<t) from the history
        encoded_obs = self.encoder(observation).unsqueeze(dim=0).to(device=self.parms.device)
        belief, _, _, _, posterior_state, _, _ = self.transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoded_obs)  # Action and observation need extra time dimension
        belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
        action, pred_next_rew, _, _, _ = self.planner(belief, posterior_state, explore)  # Get action from planner(q(s_t|o≤t,a<t), p)

        if explore:
            action = action + self.parms.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
        action.clamp_(min=min_action, max=max_action)  # Clip action range
        next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu())  # If single env is instantiated perform single action (get item from list), else perform all actions

        return belief, posterior_state, action, next_observation, reward, done, pred_next_rew

    def fit_buffer(self, episode):
        """Fit the world model on chunks sampled from the replay buffer for
        `collect_interval` gradient steps, then log/plot the losses."""
        # Model fitting
        losses = []
        tqdm.write("Fitting buffer")
        for s in tqdm(range(self.parms.collect_interval)):

            # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
            observations, actions, rewards, nonterminals = self.D.sample(self.parms.batch_size, self.parms.chunk_size)  # Transitions start at time t = 0
            # Create initial belief and state for time t = 0
            init_belief, init_state = torch.zeros(self.parms.batch_size, self.parms.belief_size, device=self.parms.device), torch.zeros(self.parms.batch_size, self.parms.state_size, device=self.parms.device)
            encoded_obs = bottle(self.encoder, (observations[1:], ))

            # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
            beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(init_state, actions[:-1], init_belief, encoded_obs, nonterminals[:-1])

            # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
            # LOSS
            observation_loss = F.mse_loss(bottle(self.observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum((2, 3, 4)).mean(dim=(0, 1))
            kl_loss = torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), self.free_nats).mean(dim=(0, 1))
            reward_loss = F.mse_loss(bottle(self.reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))

            # Update model parameters
            self.optimiser.zero_grad()

            (observation_loss + reward_loss + kl_loss).backward()  # BACKPROPAGATION
            nn.utils.clip_grad_norm_(self.param_list, self.parms.grad_clip_norm, norm_type=2)
            self.optimiser.step()
            # Store (0) observation loss (1) reward loss (2) KL loss
            losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item()])

        # Save statistics and plot them
        losses = tuple(zip(*losses))
        self.metrics['observation_loss'].append(losses[0])
        self.metrics['reward_loss'].append(losses[1])
        self.metrics['kl_loss'].append(losses[2])

        lineplot(self.metrics['episodes'][-len(self.metrics['observation_loss']):], self.metrics['observation_loss'], 'observation_loss', self.statistics_path)
        lineplot(self.metrics['episodes'][-len(self.metrics['reward_loss']):], self.metrics['reward_loss'], 'reward_loss', self.statistics_path)
        lineplot(self.metrics['episodes'][-len(self.metrics['kl_loss']):], self.metrics['kl_loss'], 'kl_loss', self.statistics_path)

    def explore_and_collect(self, episode):
        """Run one exploratory episode, append transitions to the replay
        buffer, and update/plot reward metrics."""
        tqdm.write("Collect new data:")
        reward = 0
        # Data collection
        with torch.no_grad():
            done = False
            observation, total_reward = self.env.reset(), 0
            belief, posterior_state, action = torch.zeros(1, self.parms.belief_size, device=self.parms.device), torch.zeros(1, self.parms.state_size, device=self.parms.device), torch.zeros(1, self.env.action_size, device=self.parms.device)
            t = 0
            real_rew = []
            predicted_rew = []
            total_steps = self.parms.max_episode_length // self.env.action_repeat
            explore = True

            for t in tqdm(range(total_steps)):
                # Here we need to explore
                belief, posterior_state, action, next_observation, reward, done, pred_next_rew = self.update_belief_and_act(self.env, belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1], explore=explore)
                self.D.append(observation, action.cpu(), reward, done)
                real_rew.append(reward)
                predicted_rew.append(pred_next_rew.to(device=self.parms.device).item())
                total_reward += reward
                observation = next_observation
                if self.parms.flag_render:
                    # BUGFIX: was `env.render()` — `env` is undefined in this
                    # scope and raised NameError whenever rendering was enabled.
                    self.env.render()
                if done:
                    break

        # Update and plot train reward metrics
        self.metrics['steps'].append((t * self.env.action_repeat) + self.metrics['steps'][-1])
        self.metrics['episodes'].append(episode)
        self.metrics['train_rewards'].append(total_reward)
        self.metrics['predicted_rewards'].append(np.array(predicted_rew).sum())

        lineplot(self.metrics['episodes'][-len(self.metrics['train_rewards']):], self.metrics['train_rewards'], 'train_rewards', self.statistics_path)
        double_lineplot(self.metrics['episodes'], self.metrics['train_rewards'], self.metrics['predicted_rewards'], "train_r_vs_pr", self.statistics_path)

    def train_models(self):
        """Main loop: alternate model fitting and data collection from
        (num_init_episodes + 1) up to training_episodes; periodically test,
        checkpoint models/metrics, and store the dataset.  Returns metrics."""
        tqdm.write("Start training.")

        for episode in tqdm(range(self.parms.num_init_episodes + 1, self.parms.training_episodes)):
            self.fit_buffer(episode)
            self.explore_and_collect(episode)
            if episode % self.parms.test_interval == 0:
                self.test_model(episode)
                torch.save(self.metrics, os.path.join(self.model_path, 'metrics.pth'))
                torch.save({'transition_model': self.transition_model.state_dict(), 'observation_model': self.observation_model.state_dict(), 'reward_model': self.reward_model.state_dict(), 'encoder': self.encoder.state_dict(), 'optimiser': self.optimiser.state_dict()},  os.path.join(self.model_path, 'models_%d.pth' % episode))

            if episode % self.parms.storing_dataset_interval == 0:
                self.D.store_dataset(self.parms.dataset_path + 'dump_dataset')

        return self.metrics

    def test_model(self, episode=None):
        """Evaluate the model (no exploration noise) on parallel test
        environments, record rewards and write a reconstruction video."""
        if episode is None:
            episode = self.tested_episodes

        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()

        # Initialise parallelised test environments
        test_envs = EnvBatcher(ControlSuiteEnv, (self.parms.env_name, self.parms.seed, self.parms.max_episode_length, self.parms.bit_depth), {}, self.parms.test_episodes)
        total_steps = self.parms.max_episode_length // test_envs.action_repeat
        rewards = np.zeros(self.parms.test_episodes)

        real_rew = torch.zeros([total_steps, self.parms.test_episodes])
        predicted_rew = torch.zeros([total_steps, self.parms.test_episodes])

        with torch.no_grad():
            observation, total_rewards, video_frames = test_envs.reset(), np.zeros((self.parms.test_episodes, )), []
            belief, posterior_state, action = torch.zeros(self.parms.test_episodes, self.parms.belief_size, device=self.parms.device), torch.zeros(self.parms.test_episodes, self.parms.state_size, device=self.parms.device), torch.zeros(self.parms.test_episodes, self.env.action_size, device=self.parms.device)
            tqdm.write("Testing model.")
            for t in range(total_steps):
                belief, posterior_state, action, next_observation, rewards, done, pred_next_rew = self.update_belief_and_act(test_envs, belief, posterior_state, action, observation.to(device=self.parms.device), list(rewards), self.env.action_range[0], self.env.action_range[1])
                total_rewards += rewards.numpy()
                real_rew[t] = rewards
                predicted_rew[t] = pred_next_rew

                # NOTE(review): reads the frame from self.env while stepping
                # test_envs — looks intentional for video layout, but confirm.
                observation = self.env.get_original_frame().unsqueeze(dim=0)

                video_frames.append(make_grid(torch.cat([observation, self.observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
                observation = next_observation
                if done.sum().item() == self.parms.test_episodes:
                    break

        real_rew = torch.transpose(real_rew, 0, 1)
        predicted_rew = torch.transpose(predicted_rew, 0, 1)

        # Save and plot metrics
        self.tested_episodes += 1
        self.metrics['test_episodes'].append(episode)
        self.metrics['test_rewards'].append(total_rewards.tolist())

        lineplot(self.metrics['test_episodes'], self.metrics['test_rewards'], 'test_rewards', self.statistics_path)

        write_video(video_frames, 'test_episode_%s' % str(episode), self.video_path)  # Lossy compression
        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # Close test environments
        test_envs.close()
        return self.metrics

    def dump_plan_video(self, step_before_plan=120):
        """Run `step_before_plan` policy steps, then dump the planner's
        imagined rollout vs the real environment as a video."""
        # Number of steps before starting to collect frames to dump
        step_before_plan = min(step_before_plan, (self.parms.max_episode_length // self.env.action_repeat))

        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        video_frames = []
        reward = 0

        with torch.no_grad():
            observation = self.env.reset()
            belief, posterior_state, action = torch.zeros(1, self.parms.belief_size, device=self.parms.device), torch.zeros(1, self.parms.state_size, device=self.parms.device), torch.zeros(1, self.env.action_size, device=self.parms.device)
            tqdm.write("Executing episode.")
            for t in range(step_before_plan):
                belief, posterior_state, action, next_observation, reward, done, _ = self.update_belief_and_act(self.env, belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1])
                observation = next_observation
                video_frames.append(make_grid(torch.cat([observation.cpu(), self.observation_model(belief, posterior_state).to(device=self.parms.device).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
                if done:
                    break
            self.create_and_dump_plan(self.env, belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1])

        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # Close test environments
        self.env.close()

    def create_and_dump_plan(self, env, belief, posterior_state, action, observation, reward, min_action=-inf, max_action=inf):
        """Plan a full horizon from the current belief, execute it open-loop,
        and write a video comparing real vs predicted frames."""
        tqdm.write("Dumping plan")
        video_frames = []

        encoded_obs = self.encoder(observation).unsqueeze(dim=0)
        belief, _, _, _, posterior_state, _, _ = self.transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoded_obs)
        belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
        next_action, _, beliefs, states, plan = self.planner(belief, posterior_state, False)  # Get action from planner(q(s_t|o≤t,a<t), p)
        predicted_frames = self.observation_model(beliefs, states).to(device=self.parms.device)

        for i in range(self.parms.planning_horizon):
            # Consistency fix: clamp with the `env` parameter on both bounds
            # (previously mixed `env` for min and `self.env` for max).
            plan[i].clamp_(min=env.action_range[0], max=env.action_range[1])  # Clip action range
            next_observation, reward, done = env.step(plan[i].cpu())
            next_observation = next_observation.squeeze(dim=0)
            video_frames.append(make_grid(torch.cat([next_observation, predicted_frames[i]], dim=1) + 0.5, nrow=2).numpy())  # Decentre

        write_video(video_frames, 'dump_plan', self.dump_plan_path, dump_frame=True)
    
            
Пример #25
0
    # Define generator, encoder and discriminators
    # BicycleGAN-style trio: ResNet generator, latent encoder, PatchGAN critic.
    generator = ResNetGenerator(latent_dim,
                                img_shape,
                                n_residual_blocks,
                                device=gpu_id).to(gpu_id)
    encoder = Encoder(latent_dim).to(gpu_id)
    discriminator = PatchGANDiscriminator(img_shape).to(gpu_id)

    # init weights
    # Apply the same normal-initialisation scheme to every network.
    generator.apply(weights_init_normal)
    encoder.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Define optimizers for networks
    # Each network gets its own Adam optimizer with shared lr/betas.
    optimizer_E = torch.optim.Adam(encoder.parameters(),
                                   lr=lr_rate,
                                   betas=betas)
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=lr_rate,
                                   betas=betas)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=lr_rate,
                                   betas=betas)

    # For adversarial loss (optional to use)
    # Target labels: 1 for real samples, 0 for generated ones.
    valid = 1
    fake = 0

    # Train loss list
    list_vae_G_train_loss = []
Пример #26
0
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
fine_tune_encoder = True  # fine-tune encoder?

# Attention-based caption decoder; embeddings are also fine-tuned.
decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=len(vocab),
                               dropout=dropout)
decoder.fine_tune_embeddings(True)
decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   decoder.parameters()),
                                     lr=decoder_lr)
# CNN encoder; only gets an optimizer when fine-tuning is enabled.
encoder = Encoder()
encoder.fine_tune(fine_tune_encoder)
encoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder.parameters()),
    lr=encoder_lr) if fine_tune_encoder else None

# Cosine-annealed learning rates (T_max=5 epochs, floor 1e-5).
decoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(decoder_optimizer,
                                                           5,
                                                           eta_min=1e-5,
                                                           last_epoch=-1)
# BUGFIX: guard the encoder scheduler — when fine_tune_encoder is False the
# encoder_optimizer is None and constructing the scheduler would raise.
encoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    encoder_optimizer, 5, eta_min=1e-5,
    last_epoch=-1) if encoder_optimizer is not None else None

encoder = encoder.cuda()
decoder = decoder.cuda()

train_loader = generate_data_loader(train_root, 64, int(150000))
val_loader = generate_data_loader(val_root, 50, int(10000))
def main():
    """
    Training and validation.

    Drives the image-captioning training loop: loads the word map, builds or
    restores encoder/decoder (+ optimizers) from a checkpoint, then alternates
    one epoch of training with one epoch of validation, tracking BLEU-4 and
    saving checkpoints.  Relies on module-level configuration globals.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map (w2i)
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        # Fresh start: build decoder/encoder and their optimizers from scratch.
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        # Encoder only gets an optimizer when fine-tuning is enabled.
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        # Resume: restore models, optimizers and bookkeeping from the checkpoint.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # Fine-tuning may be switched on after a checkpoint that trained
        # without it — create the missing encoder optimizer in that case.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders (This page details the preprocessing or transformation we need to perform –
    # pixel values must be in the range [0,1] and we must then normalize the image by the mean and standard
    # deviation of the ImageNet images' RGB channels.)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    # NOTE(review): shuffle=True on the validation loader is unusual (harmless
    # for metrics but non-deterministic batch order) — confirm intent.
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    initial_time = time.time()
    print("Initial time", initial_time)

    # Epochs
    for epoch in range(start_epoch, epochs):
        print("Starting epoch ", epoch)

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              initial_time=initial_time)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
Пример #28
0
    netD_A.load_state_dict(torch.load('output modified/netD_A.pth'))
    netD_B.load_state_dict(torch.load('output modified/netD_B.pth'))
else:
    encoder.apply(weights_init_normal)
    decoder_A2B.apply(weights_init_normal)
    decoder_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

# Losses
# GAN loss is least-squares (MSE) style; cycle and identity use L1.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
# One generator optimizer over the shared encoder and both decoders,
# plus one optimizer per discriminator.
optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(),
                                               decoder_A2B.parameters(),
                                               decoder_B2A.parameters()),
                               lr=lr,
                               betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                 lr=lr,
                                 betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                 lr=lr,
                                 betas=(0.5, 0.999))

# Linear LR decay for the generator after decay_epoch (LambdaLR schedule).
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=LambdaLR(
                                                       n_epochs, start_epoch,
                                                       decay_epoch).step)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = args.batch_size

    # image size 3, 32, 32
    # batch size must be an even number
    # shuffle must be True

    cifar_10_train_dt = CIFAR10('data',  download=True, train=True, transform=ToTensor())
    cifar_10_train_l = DataLoader(cifar_10_train_dt, batch_size=batch_size, shuffle=True, drop_last=True,
                                  pin_memory=torch.cuda.is_available())

    encoder = Encoder().to(device)
    # Making it local only (Deep InfoMax loss with global weight 0, local 1, prior 0.1)
    loss_fn = DeepInfoMaxLoss(0, 1, 0.1).to(device)
    encoder_optim = Adam(encoder.parameters(), lr=1e-4)
    loss_optim = Adam(loss_fn.parameters(), lr=1e-4)

    # Resume from epoch 20 weights if present under models/.
    epoch_restart = 20
    root = Path(r'models')

    if epoch_restart > 0 and root is not None:
        enc_file = root / Path('encoder' + str(epoch_restart) + '.wgt')
        loss_file = root / Path('loss' + str(epoch_restart) + '.wgt')
        encoder.load_state_dict(torch.load(str(enc_file)))
        loss_fn.load_state_dict(torch.load(str(loss_file)))

    for epoch in range(epoch_restart + 1, 201):
        batch = tqdm(cifar_10_train_l, total=len(cifar_10_train_dt) // batch_size)
        train_loss = []
        for x, target in batch:
Пример #30
0
def _ckpt_path(args, kind):
    """Build the checkpoint file path for one component.

    kind is one of 'encoder', 'decoder', 'prior', 'varflow'.  The single
    template here replaces eight hand-copied format strings; the remaining
    placeholders encode every hyper-parameter that identifies the run.
    """
    return ('./models/{}/best_{}_data_{}_skips_{}_prior_{}_numflows_{}'
            '_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
            ).format(args.data, kind, args.data, args.use_skips,
                     args.nf_prior, args.num_flows_prior, args.nf_vardistr,
                     args.num_flows_vardistr, args.n_samples, args.z_dim,
                     args.use_reparam)


def _build_prior_flow(args):
    """Construct the prior normalizing flow, or (None, []) when disabled.

    Returns (flow_module_list, trainable_params).  Supports 'IAF'
    (affine autoregressive) and 'RNVP' (affine coupling) priors.
    """
    if not args.nf_prior:
        return None, []
    flows = []
    for _ in range(args.num_flows_prior):
        if args.nf_prior == 'IAF':
            one_arn = AutoRegressiveNN(args.z_dim,
                                       [2 * args.z_dim]).to(args.device)
            flows.append(AffineAutoregressive(one_arn))
        elif args.nf_prior == 'RNVP':
            # Coupling layer splits z in half; each half of the hypernet
            # output parameterizes scale/shift for the untouched half.
            split = args.z_dim // 2
            hypernet = DenseNN(
                input_dim=split,
                hidden_dims=[2 * args.z_dim, 2 * args.z_dim],
                param_dims=[args.z_dim - split,
                            args.z_dim - split]).to(args.device)
            flows.append(AffineCoupling(split, hypernet).to(args.device))
    prior_flow = nn.ModuleList(flows)
    return prior_flow, list(prior_flow.parameters())


def _build_variational_flow(args):
    """Construct the variational (posterior) flow, or (None, []) when disabled.

    Returns (flow_module_list, trainable_params).  Uses neural
    autoregressive transforms with 256 hidden units each.
    """
    if not args.nf_vardistr:
        return None, []
    flows = []
    for _ in range(args.num_flows_vardistr):
        one_arn = AutoRegressiveNN(args.z_dim, [2 * args.z_dim],
                                   param_dims=[2 * args.z_dim] * 3).to(
                                       args.device)
        flows.append(NeuralAutoregressive(one_arn, hidden_units=256))
    variational_flow = nn.ModuleList(flows)
    return variational_flow, list(variational_flow.parameters())


def train_vae(args):
    """Train a VAE (optionally with normalizing-flow prior/posterior).

    Trains encoder/decoder (plus optional flows) on args' dataset,
    checkpoints whichever model achieves the best validation metric, and
    early-stops after ``args.early_stopping_tolerance`` epochs without
    improvement.

    Returns:
        (encoder, decoder, prior_flow, variational_flow, data) — the best
        (reloaded) models; flows are None when the corresponding
        ``args.nf_prior`` / ``args.nf_vardistr`` option is off.

    Raises:
        ValueError: if the validation metric becomes NaN.
    """
    best_metric = -float("inf")

    data = Dataset(args)
    if args.data in ['goodreads', 'big_dataset']:
        args.feature_shape = data.feature_shape

    prior_flow, prior_params = _build_prior_flow(args)

    # NOTE(review): data values other than these leave `encoder`/`decoder`
    # unbound and fail later with NameError — assumed unreachable by CLI
    # validation upstream; confirm.
    if args.data == 'mnist':
        encoder = Encoder(args).to(args.device)
    elif args.data in ['goodreads', 'big_dataset']:
        encoder = Encoder_rec(args).to(args.device)

    variational_flow, varflow_params = _build_variational_flow(args)

    if args.data == 'mnist':
        decoder = Decoder(args).to(args.device)
    elif args.data in ['goodreads', 'big_dataset']:
        decoder = Decoder_rec(args).to(args.device)

    params = list(encoder.parameters()) + list(
        decoder.parameters()) + prior_params + varflow_params
    optimizer = torch.optim.Adam(params=params)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=100,
                                                gamma=0.1)

    current_tolerance = 0
    for ep in tqdm(range(args.num_epoches)):
        # ---- training cycle ----
        for batch_num, batch_train in enumerate(data.next_train_batch()):
            # Repeat the batch n_samples times along dim 0 so each example
            # gets several posterior samples per step.
            batch_train_repeated = batch_train.repeat(
                *[[args.n_samples] + [1] * (len(batch_train.shape) - 1)])
            mu, sigma = encoder(batch_train_repeated)
            sum_log_sigma = torch.sum(torch.log(sigma), 1)
            sum_log_jacobian = 0.
            eps = args.std_normal.sample(mu.shape)
            z = mu + sigma * eps
            if not args.use_reparam:
                # Block gradients through the sampling path (score-function
                # style objective instead of reparameterization).
                z = z.detach()
            if variational_flow:
                # Push z through the posterior flow, accumulating
                # log|det J| for the ELBO correction term.
                prev_v = z
                for flow_num in range(args.num_flows_vardistr):
                    u = variational_flow[flow_num](prev_v)
                    sum_log_jacobian += variational_flow[
                        flow_num].log_abs_det_jacobian(prev_v, u)
                    prev_v = u
                z = u
            logits = decoder(z)
            elbo = compute_objective(args=args,
                                     x_logits=logits,
                                     x_true=batch_train_repeated,
                                     sampled_noise=eps,
                                     inf_samples=z,
                                     sum_log_sigma=sum_log_sigma,
                                     prior_flow=prior_flow,
                                     sum_log_jacobian=sum_log_jacobian,
                                     mu=mu,
                                     sigma=sigma)
            # Maximize the ELBO == minimize its negation.
            (-elbo).backward()
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()

        # ---- validation ----
        with torch.no_grad():
            metric = validate_vae(args=args,
                                  encoder=encoder,
                                  decoder=decoder,
                                  dataset=data,
                                  prior_flow=prior_flow,
                                  variational_flow=variational_flow)
            # NaN != NaN, so a nonzero sum here means NaNs appeared.
            if (metric != metric).sum():
                print('NAN appeared!')
                raise ValueError
            if metric > best_metric:
                current_tolerance = 0
                best_metric = metric
                # exist_ok avoids the check-then-create race of the old
                # os.path.exists + makedirs pair.
                os.makedirs('./models/{}/'.format(args.data), exist_ok=True)
                torch.save(encoder, _ckpt_path(args, 'encoder'))
                torch.save(decoder, _ckpt_path(args, 'decoder'))
                if args.nf_prior:
                    torch.save(prior_flow, _ckpt_path(args, 'prior'))
                if args.nf_vardistr:
                    torch.save(variational_flow, _ckpt_path(args, 'varflow'))
            else:
                current_tolerance += 1
                if current_tolerance >= args.early_stopping_tolerance:
                    print(
                        "Early stopping on epoch {} (effectively trained for {} epoches)"
                        .format(ep, ep - args.early_stopping_tolerance))
                    break
            print(
                'Current epoch: {}'.format(ep), '\t',
                'Current validation {}: {}'.format(args.metric_name, metric),
                '\t', 'Best validation {}: {}'.format(args.metric_name,
                                                      best_metric))

    # ---- reload and return the best checkpoints ----
    # map_location keeps the reload working on a device different from the
    # one the checkpoint was saved on (e.g. CUDA-trained, CPU-served).
    encoder = torch.load(_ckpt_path(args, 'encoder'),
                         map_location=args.device)
    decoder = torch.load(_ckpt_path(args, 'decoder'),
                         map_location=args.device)
    if args.nf_prior:
        prior_flow = torch.load(_ckpt_path(args, 'prior'),
                                map_location=args.device)
    if args.nf_vardistr:
        variational_flow = torch.load(_ckpt_path(args, 'varflow'),
                                      map_location=args.device)
    return encoder, decoder, prior_flow, variational_flow, data