Example #1
def get_models():

    import os

    from models import Encoder, Decoder
    from model_config import Config

    print('Initializing configuration...')
    config = Config()

    print('Initializing models...')
    encoder = Encoder(encoder_name=config.ENCODER_NAME, show_feature_dims=True)
    decoder = Decoder(encoder_dim=encoder.encoder_dim,
                      decoder_dim=config.decoder_dim,
                      attention_dim=config.attention_dim,
                      action_dim=config.action_dim,
                      num_loc=encoder.num_loc,
                      y_keys_info=config.y_keys_info,
                      num_layers=config.num_layers,
                      dropout_prob=config.dropout_prob)
    encoder.cuda()
    decoder.cuda()

    params_list = os.listdir(config.params_dir)
    states = load_lastest_states(config.params_dir, params_list)

    encoder.load_state_dict(states['encoder'])
    decoder.load_state_dict(states['decoder'])

    return encoder, decoder, config.init_y
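`load_lastest_states` (name kept as in the source) is defined elsewhere in this project. A minimal sketch of what such a helper might look like, assuming each file in `params_list` is a checkpoint saved with torch.save and containing the 'encoder' and 'decoder' state dicts:

import os
import torch

def load_lastest_states(params_dir, params_list):
    # Hypothetical helper: pick the most recently modified checkpoint file
    # and load its saved state dicts.
    paths = [os.path.join(params_dir, p) for p in params_list]
    latest = max(paths, key=os.path.getmtime)
    return torch.load(latest, map_location='cpu')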
Example #2
def init_model(network_type, restore, num_classes):
    """Init models with cuda and weights."""
    # init weights of model
    if network_type == "src_encoder":
        net = Encoder(num_classes,domain="src",name=params.srcenc_name)
    elif network_type == "src_classifier":
        net = Classifier(num_classes)
    elif network_type == "tgt_encoder":
        net = Encoder(num_classes,domain="tgt",name=params.tgtenc_name)
    elif network_type == "discriminator":
        net = Discriminator(input_dims=params.d_input_dims,
                hidden_dims=params.d_hidden_dims,
                output_dims=params.d_output_dims)
    else:
        print("[util.py] INFO | Network type not implemented.")

    #TODO: Initialise with pretrained resnet18 models for our dataset.
    # net.apply(init_weights)

    # restore model weights
    if restore is not None and os.path.exists(restore):
        net.load_state_dict(torch.load(restore))
        net.restored = True
        print("[utils.py] INFO | Restore model from: {}".format(os.path.abspath(restore)))

    # move the model to GPU if available
    if torch.cuda.is_available():
        cudnn.benchmark = True
    net = net.to(device)
    return net
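The commented-out `net.apply(init_weights)` call above refers to a helper that is not shown. A common shape for such an initializer (an illustration, not this project's actual implementation):

import torch.nn as nn

def init_weights(m):
    # Xavier-initialize linear and conv layers; zero any biases.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)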
Example #3
    def init_encoder(self):
        # `encoder_leky_reul` is the attribute name as defined (misspelled) on this class
        leaky_relu_params = self.encoder_leky_reul
        dropout_params = self.encoder_dropout

        encoder = Encoder(self.mu_dim, self.update_set_size, leaky_relu_params,
                          dropout_params)
        encoder.cuda()

        return encoder
Example #4
def pretrain(source_data_loader,
             test_data_loader,
             no_classes,
             embeddings,
             epochs=20,
             batch_size=128,
             cuda=False):

    classifier = Classifier()
    encoder = Encoder(embeddings)

    if cuda:
        classifier.cuda()
        encoder.cuda()
    ''' Jointly optimize both encoder and classifier '''
    encoder_params = filter(lambda p: p.requires_grad, encoder.parameters())
    optimizer = optim.Adam(
        list(encoder_params) + list(classifier.parameters()))

    # Weight the loss inversely by class frequency to compensate for
    # class imbalance in the data
    weights = torch.FloatTensor(len(no_classes))
    for i, b in enumerate(no_classes):
        weights[i] = 0 if b == 0 else 1.0 / b

    loss_fn = nn.CrossEntropyLoss(weight=weights)

    print('Training encoder and classifier')
    for e in range(epochs):

        # pretrain with whole source data -- use groups with DCD
        for sample in source_data_loader:
            x, y = Variable(sample[0]), Variable(sample[1])
            optimizer.zero_grad()

            if cuda:
                x, y = x.cuda(), y.cuda()

            output = model_fn(encoder, classifier)(x)

            loss = loss_fn(output, y)

            loss.backward()

            optimizer.step()

        print("Epoch", e, "Loss", loss.data[0], "Accuracy",
              eval_on_test(test_data_loader, model_fn(encoder, classifier)))

    return encoder, classifier
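`model_fn` composes the encoder and classifier into a single callable. A minimal sketch consistent with how it is used above (the helper itself is not shown, so this is an assumption):

def model_fn(encoder, classifier):
    # Return a callable that feeds inputs through the encoder,
    # then the classifier.
    def forward(x):
        return classifier(encoder(x))
    return forward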
Example #5
def encoder_test(seq_len=4, decoder_batch_size=2, model_name='xception'):
    encoder = Encoder(seq_len=seq_len, decoder_batch_size=decoder_batch_size, model_name=model_name)
    encoder.cuda()
    encoder.eval()

    with torch.no_grad():
        images = []
        for i in range(decoder_batch_size):
            images.append(torch.rand((seq_len, 3, 299, 299)))
        images = torch.stack(images).cuda()
        features = encoder(images)

        split_features = []
        for i in range(decoder_batch_size):
            split_features.append(encoder._forward_old(images[i]))
        split_features = torch.stack(split_features)

    assert torch.all(split_features == features)
    print('encoder test passed!')
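Run directly (for instance behind a __main__ guard), the test exercises both forward paths with its default settings:

if __name__ == '__main__':
    encoder_test(seq_len=4, decoder_batch_size=2, model_name='xception')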
Example #6
def pretrain(data, epochs=5, batch_size=128, cuda=False):

    X_s, y_s, _, _ = data

    test_dataloader = mnist_dataloader(train=False, cuda=cuda)

    classifier = Classifier()
    encoder = Encoder()

    if cuda:
        classifier.cuda()
        encoder.cuda()

    ''' Jointly optimize both encoder and classifier ''' 
    optimizer = optim.Adam(list(encoder.parameters()) + list(classifier.parameters()))
    loss_fn = nn.CrossEntropyLoss()
    
    for e in range(epochs):
        
        for _ in range(len(X_s) // batch_size):
            inds = torch.randperm(len(X_s))[:batch_size]

            x, y = Variable(X_s[inds]), Variable(y_s[inds])
            optimizer.zero_grad()

            if cuda:
                x, y = x.cuda(), y.cuda()

            y_pred = model_fn(encoder, classifier)(x)

            loss = loss_fn(y_pred, y)

            loss.backward()

            optimizer.step()

        print("Epoch", e, "Loss", loss.data[0], "Accuracy", eval_on_test(test_dataloader, model_fn(encoder, classifier)))
    
    return encoder, classifier
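`eval_on_test` is another helper these snippets assume. A plausible sketch that computes top-1 accuracy over a test loader (names and details are assumptions, not taken from the source):

import torch

def eval_on_test(dataloader, model):
    # Hypothetical accuracy helper matching the calls above.
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total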
Example #7
decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   decoder.parameters()),
                                     lr=decoder_lr)
encoder = Encoder()
encoder.fine_tune(fine_tune_encoder)
encoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder.parameters()),
    lr=encoder_lr) if fine_tune_encoder else None

decoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(decoder_optimizer,
                                                           5,
                                                           eta_min=1e-5,
                                                           last_epoch=-1)
encoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(encoder_optimizer,
                                                           5,
                                                           eta_min=1e-5,
                                                           last_epoch=-1) if fine_tune_encoder else None
encoder = encoder.cuda()
decoder = decoder.cuda()

train_loader = generate_data_loader(train_root, 64, int(150000))
val_loader = generate_data_loader(val_root, 50, int(10000))
criterion = nn.CrossEntropyLoss().to(device)


class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
Example #8
def main(config, needs_save):
    os.environ['CUDA_VISIBLE_DEVICES'] = config.training.visible_devices
    seed = check_manual_seed(config.training.seed)
    print('Using manual seed: {}'.format(seed))

    if config.dataset.patient_ids == 'TRAIN_PATIENT_IDS':
        patient_ids = TRAIN_PATIENT_IDS
    elif config.dataset.patient_ids == 'TEST_PATIENT_IDS':
        patient_ids = TEST_PATIENT_IDS
    else:
        raise NotImplementedError

    data_loader = get_data_loader(
        mode=config.dataset.mode,
        dataset_name=config.dataset.name,
        patient_ids=patient_ids,
        root_dir_path=config.dataset.root_dir_path,
        use_augmentation=config.dataset.use_augmentation,
        batch_size=config.dataset.batch_size,
        num_workers=config.dataset.num_workers,
        image_size=config.dataset.image_size)

    E = Encoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.enc_filters,
                activation=config.model.enc_activation).float()

    D = Decoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.dec_filters,
                activation=config.model.dec_activation,
                final_activation=config.model.dec_final_activation).float()

    if config.model.enc_spectral_norm:
        apply_spectral_norm(E)

    if config.model.dec_spectral_norm:
        apply_spectral_norm(D)

    if config.training.use_cuda:
        E.cuda()
        D.cuda()
        E = nn.DataParallel(E)
        D = nn.DataParallel(D)

    if config.model.saved_E:
        print(config.model.saved_E)
        E.load_state_dict(torch.load(config.model.saved_E))

    if config.model.saved_D:
        print(config.model.saved_D)
        D.load_state_dict(torch.load(config.model.saved_D))

    print(E)
    print(D)

    e_optim = optim.Adam(filter(lambda p: p.requires_grad, E.parameters()),
                         config.optimizer.enc_lr, [0.9, 0.9999])

    d_optim = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()),
                         config.optimizer.dec_lr, [0.9, 0.9999])

    alpha = config.training.alpha
    beta = config.training.beta
    margin = config.training.margin

    batch_size = config.dataset.batch_size
    fixed_z = torch.randn(calc_latent_dim(config))

    if 'ssim' in config.training.loss:
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

    def l_recon(recon: torch.Tensor, target: torch.Tensor):
        if config.training.loss == 'l2':
            loss = F.mse_loss(recon, target, reduction='sum')

        elif config.training.loss == 'l1':
            loss = F.l1_loss(recon, target, reduction='sum')

        elif config.training.loss == 'ssim':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon)

        elif config.training.loss == 'ssim+l1':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                 + F.l1_loss(recon, target, reduction='sum')

        elif config.training.loss == 'ssim+l2':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                 + F.mse_loss(recon, target, reduction='sum')

        else:
            raise NotImplementedError

        return beta * loss / batch_size

    def l_reg(mu: torch.Tensor, log_var: torch.Tensor):
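        # Standard VAE KL term: KL( N(mu, sigma^2) || N(0, I) ), summed over the batch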
        loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
        return loss / batch_size

    def update(engine, batch):
        E.train()
        D.train()

        image = norm(batch['image'])

        if config.training.use_cuda:
            image = image.cuda(non_blocking=True).float()
        else:
            image = image.float()

        e_optim.zero_grad()
        d_optim.zero_grad()

        z, z_mu, z_logvar = E(image)
        x_r = D(z)

        l_vae_reg = l_reg(z_mu, z_logvar)
        l_vae_recon = l_recon(x_r, image)
        l_vae_total = l_vae_reg + l_vae_recon

        l_vae_total.backward()

        e_optim.step()
        d_optim.step()

        if config.training.use_cuda:
            torch.cuda.synchronize()

        return {
            'TotalLoss': l_vae_total.item(),
            'EncodeLoss': l_vae_reg.item(),
            'ReconLoss': l_vae_recon.item(),
        }

    output_dir = get_output_dir_path(config)
    trainer = Engine(update)
    timer = Timer(average=True)

    monitoring_metrics = ['TotalLoss', 'EncodeLoss', 'ReconLoss']

    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(
                                                    trainer, metric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def save_config(engine):
        config_to_save = defaultdict(dict)

        for key, child in config._asdict().items():
            for k, v in child._asdict().items():
                config_to_save[key][k] = v

        config_to_save['seed'] = seed
        config_to_save['output_dir'] = output_dir

        print('Training starts with the following configuration: ',
              config_to_save)

        if needs_save:
            save_path = os.path.join(output_dir, 'config.json')
            with open(save_path, 'w') as f:
                json.dump(config_to_save, f)

    @trainer.on(Events.ITERATION_COMPLETED)
    def show_logs(engine):
        if (engine.state.iteration - 1) % config.save.log_iter_interval == 0:
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                   + [str(value) for value in engine.state.metrics.values()]

            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=config.training.n_epochs,
                i=engine.state.iteration,
                max_i=len(data_loader))

            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_logs(engine):
        if needs_save:
            fname = os.path.join(output_dir, 'logs.tsv')
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                   + [str(value) for value in engine.state.metrics.values()]

            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        if needs_save:
            if engine.state.epoch % config.save.save_epoch_interval == 0:
                image = norm(engine.state.batch['image'])

                if config.training.use_cuda:
                    image = image.cuda(non_blocking=True).float()
                else:
                    image = image.float()

                with torch.no_grad():
                    z, _, _ = E(image)
                    x_r = D(z)
                    x_p = D(fixed_z)

                image = denorm(image).detach().cpu()
                x_r = denorm(x_r).detach().cpu()
                x_p = denorm(x_p).detach().cpu()

                image = image[:config.save.n_save_images, ...]
                x_r = x_r[:config.save.n_save_images, ...]
                x_p = x_p[:config.save.n_save_images, ...]

                save_path = os.path.join(
                    output_dir, 'result_{}.png'.format(engine.state.epoch))
                save_image(torch.cat([image, x_r, x_p]).data, save_path)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.save.n_saved,
            create_dir=True,
        )
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'E': E,
                                      'D': D
                                  })

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.training.n_epochs, config.training.n_epochs * len(data_loader)))

    trainer.run(data_loader, config.training.n_epochs)
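`norm` and `denorm` above are project utilities that are not shown. A minimal sketch mapping images between [0, 1] and [-1, 1], a common convention that is assumed here:

def norm(x):
    # Map [0, 1] images to [-1, 1].
    return x * 2.0 - 1.0

def denorm(x):
    # Map [-1, 1] images back to [0, 1] and clamp.
    return ((x + 1.0) / 2.0).clamp(0.0, 1.0)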
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

###### Definition of variables ######
# Networks
encoder = Encoder(input_nc=input_nc)
decoder_A2B = Decoder(output_nc=output_nc)
decoder_B2A = Decoder(output_nc=output_nc)
netD_A = DiscriminatorNew(input_nc)
netD_B = DiscriminatorNew(output_nc)
loss_GAN_hist = []
loss_cycle_hist = []
loss_D_hist = []

if activate_cuda:
    encoder.cuda()
    decoder_A2B.cuda()
    decoder_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

if os.path.isfile('output modified/encoder.pth'):
    encoder.load_state_dict(torch.load('output modified/encoder.pth'))
    decoder_A2B.load_state_dict(torch.load('output modified/decoder_A2B.pth'))
    decoder_B2A.load_state_dict(torch.load('output modified/decoder_B2A.pth'))
    netD_A.load_state_dict(torch.load('output modified/netD_A.pth'))
    netD_B.load_state_dict(torch.load('output modified/netD_B.pth'))
else:
    encoder.apply(weights_init_normal)
    decoder_A2B.apply(weights_init_normal)
    decoder_B2A.apply(weights_init_normal)
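`weights_init_normal` is not shown; a sketch of the common CycleGAN-style initializer pattern it presumably follows (an assumption, not taken from this repository):

import torch.nn as nn

def weights_init_normal(m):
    # Normal-initialize conv weights; scale-and-shift init for batch norm.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)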
Example #10
#                 nheads=args.nb_heads,
#                 alpha=args.alpha)
# else:
#     model = GAT(nfeat=3,
#                 nhid=args.hidden,
#                 nclass=12,
#                 dropout=args.dropout,
#                 nheads=args.nb_heads,
#                 alpha=args.alpha)
model = Encoder(output_size=(7, 7), spatial_scale=1.0, hidden=args.hidden, nclass=12,
                dropout=args.dropout, nb_heads=args.nb_heads,  alpha=args.alpha)
optimizer = optim.SGD(model.parameters(),
                      lr=args.lr)

if args.cuda:
    model.cuda()
    # features = features.cuda()
    # adj = adj.cuda()
    # labels = labels.cuda()
    # idx_train = idx_train.cuda()
    # idx_val = idx_val.cuda()
    # idx_test = idx_test.cuda()

# features, adj, labels = Variable(features), Variable(adj), Variable(labels)


def train(epoch, train_loader, val_loader, logger=None):
    t = time.time()
    # model.eval()
    model.train()
Example #11
def main(args):
    """
    Training and validation.
    """
    with open(args.word_map_file, 'rb') as f:
        word_map = pickle.load(f)


    # choose which encoder to use
    encoder = Encoder(input_feature_dim=args.input_feature_dim,
                      encoder_hidden_dim=args.encoder_hidden_dim,
                      encoder_layer=args.encoder_layer,
                      rnn_unit=args.rnn_unit,
                      use_gpu=args.CUDA,
                      dropout_rate=args.dropout_rate)

    # encoder = EncoderT(input_feature_dim=args.input_feature_dim,
    #                    encoder_hidden_dim=args.encoder_hidden_dim,
    #                    encoder_layer=args.encoder_layer,
    #                    rnn_unit=args.rnn_unit,
    #                    use_gpu=args.CUDA,
    #                    dropout_rate=args.dropout,
    #                    nhead=args.nhead
    #                    )

    decoder = DecoderWithAttention(attention_dim=args.attention_dim,
                                    embed_dim=args.emb_dim,
                                    decoder_dim=args.decoder_dim,
                                    vocab_size=len(word_map),
                                    dropout=args.dropout)
                                    
    if args.resume:
        encoder.load_state_dict(torch.load(args.encoder_path))
        decoder.load_state_dict(torch.load(args.decoder_path))


    encoder_parameter = [p for p in encoder.parameters() if p.requires_grad]  # keep only trainable parameters
    decoder_parameter = [p for p in decoder.parameters() if p.requires_grad]

    encoder_optimizer = torch.optim.Adam(encoder_parameter, lr=args.decoder_lr)  # Adam for both; note both optimizers use args.decoder_lr
    decoder_optimizer = torch.optim.Adam(decoder_parameter, lr=args.decoder_lr)
    
    if args.CUDA:
        decoder = decoder.cuda()
        encoder = encoder.cuda()
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = nn.CrossEntropyLoss()  # just fall back to CPU then
    
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(args.data_path, split='TRAIN'),
        batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
        collate_fn=pad_collate_train, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(args.data_path, split='VAL'),
        batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
        collate_fn=pad_collate_train, pin_memory=True)

    # Epochs
    best_bleu4 = 0
    for epoch in range(args.start_epoch, args.epochs):
        
        losst = train(train_loader=train_loader,  # training loss for this epoch
                      encoder=encoder,
                      decoder=decoder,
                      criterion=criterion,
                      encoder_optimizer=encoder_optimizer,
                      decoder_optimizer=decoder_optimizer,
                      epoch=epoch, args=args)

        # One epoch's validation
        if epoch % 1 == 0:  # validate every epoch
            lossv = validate(val_loader=val_loader,   
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion,
                            best_bleu=best_bleu4,
                            args=args) 

        info = 'LOSST - {losst:.4f}, LOSSv - {lossv:.4f}\n'.format(
                losst=losst,
                lossv=lossv)


        with open(dev, "a") as f:    ## de los moet ook voor de validation 
            f.write(info)
            f.write("\n")  

        # Selecting on the basis of BLEU would work as follows:
        # print('BLEU4: ' + bleu4)
        # print('best_bleu4 ' + best_bleu4)
        # if bleu4 > best_bleu4:
        if epoch % 3 == 0:
            save_checkpoint(epoch, encoder, decoder, encoder_optimizer,
                            decoder_optimizer, lossv)
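`save_checkpoint` is not shown in this example. A hypothetical helper matching the call signature above:

import torch

def save_checkpoint(epoch, encoder, decoder, encoder_optimizer,
                    decoder_optimizer, loss):
    # Bundle everything needed to resume training into one file.
    torch.save({'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'encoder_optimizer': encoder_optimizer.state_dict(),
                'decoder_optimizer': decoder_optimizer.state_dict(),
                'loss': loss},
               'checkpoint_epoch_{}.pth'.format(epoch))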
Example #12
class AAETrainer(AbstractTrainer):
    def __init__(self, opt):
        super().__init__(opt)

        print('[info] Dataset:', self.opt.dataset)
        print('[info] Alpha = ', self.opt.alpha)
        print('[info] Latent dimension = ', self.opt.latent_dim)

        self.opt = opt
        self.start_visdom()

    def start_visdom(self):
        self.vis = utils.Visualizer(env='Adversarial AutoEncoder Training',
                                    port=8888)

    def build_network(self):
        print('[info] Build the network architecture')
        self.encoder = Encoder(z_dim=self.opt.latent_dim)
        if self.opt.dataset == 'SMPL':
            num_verts = 6890
        elif self.opt.dataset == 'all_animals':
            num_verts = 3889
        else:
            raise ValueError('Unknown dataset: {}'.format(self.opt.dataset))
        self.decoder = Decoder(num_verts=num_verts, z_dim=self.opt.latent_dim)
        self.discriminator = Discriminator(input_dim=self.opt.latent_dim)

        self.encoder.cuda()
        self.decoder.cuda()
        self.discriminator.cuda()

    def build_optimizer(self):
        print('[info] Build the optimizer')
        self.optim_dis = optim.SGD(self.discriminator.parameters(),
                                   lr=self.opt.learning_rate)
        self.optim_AE = optim.Adam(itertools.chain(self.encoder.parameters(),
                                                   self.decoder.parameters()),
                                   lr=self.opt.learning_rate)

    def build_dataset_train(self):
        train_data = ACAPData(mode='train', name=self.opt.dataset)
        self.num_train_data = len(train_data)
        print('[info] Number of training samples = ', self.num_train_data)
        self.train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=self.opt.batch_size, shuffle=True)

    def build_dataset_valid(self):
        valid_data = ACAPData(mode='valid', name=self.opt.dataset)
        self.num_valid_data = len(valid_data)
        print('[info] Number of validation samples = ', self.num_valid_data)
        self.valid_loader = torch.utils.data.DataLoader(valid_data,
                                                        batch_size=128,
                                                        shuffle=True)

    def build_losses(self):
        print('[info] Build the loss functions')
        self.mseLoss = torch.nn.MSELoss()
        self.ganLoss = torch.nn.BCELoss()

    def print_iteration_stats(self):
        """
        print stats at each iteration
        """
        print(
            '\r[Epoch %d] [Iteration %d/%d] enc = %f dis = %f rec = %f' %
            (self.epoch, self.iteration,
             int(self.num_train_data / self.opt.batch_size),
             self.enc_loss.item(), self.dis_loss.item(), self.rec_loss.item()),
            end='')

    def train_iteration(self):

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        x = self.data.cuda()

        z = self.encoder(x)
        ''' Discriminator '''
        # sample from N(0, I)
        z_real = Variable(torch.randn(z.size(0), z.size(1))).cuda()

        y_real = Variable(torch.ones(z.size(0))).cuda()
        dis_real_loss = self.ganLoss(
            self.discriminator(z_real).view(-1), y_real)

        y_fake = Variable(torch.zeros(z.size(0))).cuda()
        dis_fake_loss = self.ganLoss(self.discriminator(z).view(-1), y_fake)

        self.optim_dis.zero_grad()
        self.dis_loss = 0.5 * (dis_fake_loss + dis_real_loss)
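        # retain_graph=True: the graph through z is reused below for the
        # encoder and reconstruction losses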
        self.dis_loss.backward(retain_graph=True)
        self.optim_dis.step()
        self.dis_losses.append(self.dis_loss.item())
        ''' Autoencoder '''
        # The encoder tries to produce latent vectors that are close to the prior.
        y_real = Variable(torch.ones(z.size(0))).cuda()
        self.enc_loss = self.ganLoss(self.discriminator(z).view(-1), y_real)

        # The decoder tries to make the reconstruction as similar to the input as possible.
        rec = self.decoder(z)
        self.rec_loss = self.mseLoss(rec, x)

        # There is a trade-off here:
        # latent regularization vs. reconstruction quality
        self.EG_loss = self.opt.alpha * self.enc_loss + (
            1 - self.opt.alpha) * self.rec_loss

        self.optim_AE.zero_grad()
        self.EG_loss.backward()
        self.optim_AE.step()

        self.enc_losses.append(self.enc_loss.item())
        self.rec_losses.append(self.rec_loss.item())

        self.print_iteration_stats()
        self.increment_iteration()

    def train_epoch(self):

        self.reset_iteration()
        self.dis_losses = []
        self.enc_losses = []
        self.rec_losses = []
        for step, data in enumerate(self.train_loader):
            self.data = data
            self.train_iteration()

        self.dis_losses = torch.Tensor(self.dis_losses)
        self.dis_losses = torch.mean(self.dis_losses)

        self.enc_losses = torch.Tensor(self.enc_losses)
        self.enc_losses = torch.mean(self.enc_losses)

        self.rec_losses = torch.Tensor(self.rec_losses)
        self.rec_losses = torch.mean(self.rec_losses)

        self.vis.draw_line(win='Encoder Loss', x=self.epoch, y=self.enc_losses)
        self.vis.draw_line(win='Discriminator Loss',
                           x=self.epoch,
                           y=self.dis_losses)
        self.vis.draw_line(win='Reconstruction Loss',
                           x=self.epoch,
                           y=self.rec_losses)

    def valid_iteration(self):

        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        x = self.data.cuda()
        z = self.encoder(x)
        recon = self.decoder(z)

        # loss
        rec_loss = self.mseLoss(recon, x)
        self.rec_loss.append(rec_loss.item())
        self.increment_iteration()

    def valid_epoch(self):
        self.reset_iteration()
        self.rec_loss = []
        for step, data in enumerate(self.valid_loader):
            self.data = data
            self.valid_iteration()

        self.rec_loss = torch.Tensor(self.rec_loss)
        self.rec_loss = torch.mean(self.rec_loss)
        self.vis.draw_line(win='Valid reconstruction loss',
                           x=self.epoch,
                           y=self.rec_loss)

    def save_network(self):
        print("\n[info] saving net...")
        torch.save(self.encoder.state_dict(),
                   f"{self.opt.save_path}/Encoder.pth")
        torch.save(self.decoder.state_dict(),
                   f"{self.opt.save_path}/Decoder.pth")
        torch.save(self.discriminator.state_dict(),
                   f"{self.opt.save_path}/Discriminator.pth")
Example #13
parse.add_argument('--img_height',
                   type=int,
                   default=128,
                   help='size of image height')
parse.add_argument('--img_width',
                   type=int,
                   default=128,
                   help='size of image width')
opt = parse.parse_args()
#init
Net_G = Generator(opt.z_dim, (3, opt.img_height, opt.img_width))
Net_E = Encoder(opt.z_dim)
D_VAE = Discriminator()
D_LR = Discriminator()
l1_loss = torch.nn.L1Loss()
#cuda
Net_G.cuda()
Net_E.cuda()
D_VAE.cuda()
D_LR.cuda()
l1_loss.cuda()

#weight_init
Net_G.apply(weights_init)
D_VAE.apply(weights_init)
D_LR.apply(weights_init)
#optimizer
optimizer_E = torch.optim.Adam(Net_E.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))
optimizer_G = torch.optim.Adam(Net_G.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))
Example #14
decoder = GridLSTMDecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(vocab),
                                       encoder_dim=512,
                                       dropout=dropout)

decoder.fine_tune_embeddings(True)
decoder = decoder.cuda(gpu2)

decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   decoder.parameters()),
                                     lr=decoder_lr)
encoder = Encoder()
encoder.fine_tune(fine_tune_encoder)
encoder = encoder.cuda(gpu1)

encoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder.parameters()),
    lr=encoder_lr) if fine_tune_encoder else None

decoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(decoder_optimizer,
                                                           8,
                                                           eta_min=5e-6,
                                                           last_epoch=-1)
encoder_sched = torch.optim.lr_scheduler.CosineAnnealingLR(encoder_optimizer,
                                                           8,
                                                           eta_min=5e-6,
                                                           last_epoch=-1) if fine_tune_encoder else None

criterion = nn.CrossEntropyLoss().cuda(gpu2)
Example #15
def train(args):
    cfg_from_file(args.cfg)
    cfg.WORKERS = args.num_workers
    pprint.pprint(cfg)
    # set the seed manually
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # define outputer
    outputer_train = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                              cfg.IMAGETEXT.SAVE_EVERY)
    outputer_val = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                            cfg.IMAGETEXT.SAVE_EVERY)
    # define the dataset
    split_dir, bshuffle = 'train', True

    # Get data loader
    imsize = cfg.TREE.BASE_SIZE * (2**(cfg.TREE.BRANCH_NUM - 1))
    train_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.CenterCrop(imsize),
    ])
    if args.dataset == 'bird':
        train_dataset = ImageTextDataset(args.data_dir,
                                         split_dir,
                                         transform=train_transform,
                                         sample_type='train')
        val_dataset = ImageTextDataset(args.data_dir,
                                       'val',
                                       transform=val_transform,
                                       sample_type='val')
    elif args.dataset == 'coco':
        train_dataset = CaptionDataset(args.data_dir,
                                       split_dir,
                                       transform=train_transform,
                                       sample_type='train',
                                       coco_data_json=args.coco_data_json)
        val_dataset = CaptionDataset(args.data_dir,
                                     'val',
                                     transform=val_transform,
                                     sample_type='val',
                                     coco_data_json=args.coco_data_json)
    else:
        raise NotImplementedError

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=bshuffle,
        num_workers=int(cfg.WORKERS))
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=False,
        num_workers=1)
    # define the model and optimizer
    if args.raw_checkpoint != '':
        encoder, decoder = load_raw_checkpoint(args.raw_checkpoint)
    else:
        encoder = Encoder()
        decoder = DecoderWithAttention(
            attention_dim=cfg.IMAGETEXT.ATTENTION_DIM,
            embed_dim=cfg.IMAGETEXT.EMBED_DIM,
            decoder_dim=cfg.IMAGETEXT.DECODER_DIM,
            vocab_size=train_dataset.n_words)
        # load checkpoint
        if cfg.IMAGETEXT.CHECKPOINT != '':
            outputer_val.log("load model from: {}".format(
                cfg.IMAGETEXT.CHECKPOINT))
            encoder, decoder = load_checkpoint(encoder, decoder,
                                               cfg.IMAGETEXT.CHECKPOINT)

    encoder.fine_tune(False)
    # to cuda
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    if args.eval:  # eval only
        outputer_val.log("only eval the model...")
        assert cfg.IMAGETEXT.CHECKPOINT != ''
        val_rtn_dict, outputer_val = validate_one_epoch(
            0, val_dataloader, encoder, decoder, loss_func, outputer_val)
        outputer_val.log("\n[valid]: {}\n".format(dict2str(val_rtn_dict)))
        return

    # define optimizer
    optimizer_encoder = torch.optim.Adam(encoder.parameters(),
                                         lr=cfg.IMAGETEXT.ENCODER_LR)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(),
                                         lr=cfg.IMAGETEXT.DECODER_LR)
    encoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_encoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_decoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    print("train the model...")
    for epoch_idx in range(cfg.IMAGETEXT.EPOCH):
        # val_rtn_dict, outputer_val = validate_one_epoch(epoch_idx, val_dataloader, encoder,
        #         decoder, loss_func, outputer_val)
        # outputer_val.log("\n[valid] epoch: {}, {}".format(epoch_idx, dict2str(val_rtn_dict)))
        train_rtn_dict, outputer_train = train_one_epoch(
            epoch_idx, train_dataloader, encoder, decoder, optimizer_encoder,
            optimizer_decoder, loss_func, outputer_train)
        # adjust lr scheduler
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()

        outputer_train.log("\n[train] epoch: {}, {}\n".format(
            epoch_idx, dict2str(train_rtn_dict)))
        val_rtn_dict, outputer_val = validate_one_epoch(
            epoch_idx, val_dataloader, encoder, decoder, loss_func,
            outputer_val)
        outputer_val.log("\n[valid] epoch: {}, {}\n".format(
            epoch_idx, dict2str(val_rtn_dict)))

        outputer_val.save_step({
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict()
        })
    outputer_val.save({
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict()
    })
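`dict2str` is a small logging utility assumed by the calls above; a plausible sketch:

def dict2str(d):
    # Hypothetical formatter: "key1: val1, key2: val2, ..."
    return ', '.join('{}: {}'.format(k, v) for k, v in d.items())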
Example #16
# filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in src_encoder_dict}
# overwrite entries in the existing state dict
src_encoder_dict.update(pretrained_dict) 
# load the new state dict
src_encoder.load_state_dict(src_encoder_dict)
optimizer = optim.SGD(
    list(src_encoder.parameters()) + list(src_classifier.parameters()),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay)

criterion = nn.CrossEntropyLoss()

if cuda: 
    src_encoder = src_encoder.cuda()
    src_classifier = src_classifier.cuda() 
    criterion = criterion.cuda() 

src_encoder.train()
src_classifier.train()
# begin train
for epoch in range(1, epochs+1):
    correct = 0
    for batch_idx, (src_data, label) in enumerate(src_train_loader):
        if cuda:
            src_data, label = src_data.cuda(), label.cuda()
        src_data, label = Variable(src_data), Variable(label)
        optimizer.zero_grad()
        src_feature = src_encoder(src_data)
        output = src_classifier(src_feature)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        # track running accuracy over the epoch
        pred = output.max(1)[1]
        correct += pred.eq(label).sum().item()