예제 #1
0
class Solver(utils.BaseSolver):
    def build(self):
        """Instantiate the VAE and, in train mode, its optimizer and loss."""
        self.model = VAE(self.config)

        # Move parameters to the GPU before the optimizer captures them.
        if self.config.cuda:
            self.model.cuda()
        print(self.model)

        # Optimizer and loss are only needed when training.
        if self.config.mode == 'train':
            adam_betas = (self.config.beta1, 0.999)
            self.optimizer = self.config.optimizer(
                self.model.parameters(),
                lr=self.config.learning_rate,
                betas=adam_betas)
            self.loss_function = layers.VAELoss(self.config.recon_loss)

    def train_step(self, images):
        """One forward pass; returns (reconstruction loss, KL divergence)."""
        recon_images, mu, log_variance = self.model(images)
        recon_loss, kl_div = self.loss_function(
            images, recon_images, mu, log_variance)
        return recon_loss, kl_div
예제 #2
0
def load_networks(isTraining=False):
    """Build the depth, color and discriminator (VAE) networks plus their Adam
    optimizers, optionally restoring all of them from a ``.tar`` checkpoint.

    Args:
        isTraining: when True, restore from ``param.trainNet`` only if
            ``param.isContinue`` is set and a checkpoint exists (also updating
            ``param.startIter`` from the checkpoint file name); when False,
            always restore from the first checkpoint in ``param.testNet``.

    Returns:
        (depth_net, color_net, d_net,
         depth_optimizer, color_optimizer, d_optimizer)
    """
    depth_net = DepthNetModel()
    color_net = ColorNetModel()
    d_net = VAE()
    if param.useGPU:
        # Move models to GPU before the optimizers capture their parameters.
        depth_net.cuda()
        color_net.cuda()
        d_net.cuda()

    depth_optimizer = optim.Adam(depth_net.parameters(),
                                 lr=param.alpha,
                                 betas=(param.beta1, param.beta2),
                                 eps=param.eps)
    color_optimizer = optim.Adam(color_net.parameters(),
                                 lr=param.alpha,
                                 betas=(param.beta1, param.beta2),
                                 eps=param.eps)
    d_optimizer = optim.Adam(d_net.parameters())

    def _tar_checkpoints(folder):
        # Sorted names of the '*.tar' checkpoint files inside `folder`.
        names, _, _ = get_folder_content(folder)
        return [name for name in sorted(names) if name.endswith('.tar')]

    def _restore_all(folder, ckpt_name):
        # Load every model and optimizer state from one checkpoint dict.
        checkpoint = torch.load(folder + '/' + ckpt_name)
        depth_net.load_state_dict(checkpoint['depth_net'])
        color_net.load_state_dict(checkpoint['color_net'])
        d_net.load_state_dict(checkpoint['d_net'])
        depth_optimizer.load_state_dict(checkpoint['depth_optimizer'])
        color_optimizer.load_state_dict(checkpoint['color_optimizer'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer'])

    if isTraining:
        net = _tar_checkpoints(param.trainNet)  # e.g. 'TrainingData'
        if param.isContinue and net:
            # Names look like '<prefix>-<iteration>.tar'; resume the counter.
            param.startIter = int(net[0].split('-')[1].split('.')[0])
            _restore_all(param.trainNet, net[0])
        else:
            param.isContinue = False
    else:
        # Test mode: always restore from the first checkpoint found.
        net = _tar_checkpoints(param.testNet)
        _restore_all(param.testNet, net[0])

    return depth_net, color_net, d_net, depth_optimizer, color_optimizer, d_optimizer
예제 #3
0
def main():
    """Train a text VAE (EncoderRNN + DecoderRNN) on IMDB and save a
    checkpoint under ./snapshot/ whenever validation loss improves."""
    args = parse_arguments()
    hidden_size = 300   # RNN hidden-state size
    embed_size = 50     # token embedding size
    kld_weight = 0.05   # weight of the KL term in the VAE loss
    temperature = 0.9   # decoder sampling temperature
    use_cuda = torch.cuda.is_available()

    print("[!] preparing dataset...")
    # Every review is lower-cased and padded/truncated to 30 tokens.
    TEXT = data.Field(lower=True, fix_length=30)
    LABEL = data.Field(sequential=False)
    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
    TEXT.build_vocab(train_data, max_size=250000)
    LABEL.build_vocab(train_data)
    train_iter, test_iter = data.BucketIterator.splits(
        (train_data, test_data), batch_size=args.batch_size, repeat=False)
    # +2 presumably reserves ids for extra special tokens — TODO confirm.
    vocab_size = len(TEXT.vocab) + 2

    print("[!] Instantiating models...")
    encoder = EncoderRNN(vocab_size,
                         hidden_size,
                         embed_size,
                         n_layers=2,
                         dropout=0.5,
                         use_cuda=use_cuda)
    decoder = DecoderRNN(embed_size,
                         hidden_size,
                         vocab_size,
                         n_layers=2,
                         dropout=0.5,
                         use_cuda=use_cuda)
    vae = VAE(encoder, decoder)
    optimizer = optim.Adam(vae.parameters(), lr=args.lr)
    if use_cuda:
        print("[!] Using CUDA...")
        vae.cuda()

    best_val_loss = None
    for e in range(1, args.epochs + 1):
        train(e, vae, optimizer, train_iter, vocab_size, kld_weight,
              temperature, args.grad_clip, use_cuda, TEXT)
        val_loss = evaluate(vae, test_iter, vocab_size, kld_weight, use_cuda)
        # NOTE(review): the trailing 'S' in 'val_pp:%5.2fS' looks like a typo
        # in the log format string; left untouched here.
        print("[Epoch: %d] val_loss:%5.3f | val_pp:%5.2fS" %
              (e, val_loss, math.exp(val_loss)))

        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            print("[!] saving model...")
            if not os.path.isdir("snapshot"):
                os.makedirs("snapshot")
            torch.save(vae.state_dict(), './snapshot/vae_{}.pt'.format(e))
            best_val_loss = val_loss
예제 #4
0
class VAETrainer:
    """Holds a VAE, its optimizer, and train/test data loaders built from a
    dataset factory (``dataset(train=bool)`` -> torch Dataset)."""

    def __init__(self, dataset):
        self.model = VAE()
        if config.USE_GPU:
            self.model.cuda()
        # BUG FIX: the original read `self, dataset = dataset`, which unpacked
        # the argument instead of storing it — `self.dataset` was never set
        # and the DataLoader lines below raised AttributeError.
        self.dataset = dataset
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.train_loader = DataLoader(self.dataset(train=True))
        self.test_loader = DataLoader(self.dataset(train=False))

    def train(self):
        # Training loop not implemented in SOURCE; kept as a stub.
        pass

    def test(self):
        # Evaluation loop not implemented in SOURCE; kept as a stub.
        pass
예제 #5
0
def train(data_loader, model_index, x_eval_train, loaded_model):
    """Train a VAE on (x, y) batches with MSE reconstruction + weighted KLD,
    periodically saving JSON records, sample pictures, and model snapshots.

    Args:
        data_loader: iterable of (x_batch, y_batch) tensors.
        model_index: id used in record/picture/model file names.
        x_eval_train: unused here (kept for interface compatibility).
        loaded_model: optional path of a saved state dict to resume from.
    """
    ### Model Initiation
    vae = VAE()
    if loaded_model:
        # Resume from the saved state dict (original also called .cuda()
        # twice around the load; once after loading is sufficient).
        saved_state_dict = tor.load(loaded_model)
        vae.load_state_dict(saved_state_dict)
    vae = vae.cuda()

    loss_func = tor.nn.MSELoss().cuda()

    #optim = tor.optim.SGD(fcn.parameters(), lr=LR, momentum=MOMENTUM)
    optim = tor.optim.Adam(vae.parameters(), lr=LR)

    # StepLR stepped per batch, as in the original.
    lr_step = StepLR(optim, step_size=LR_STEPSIZE, gamma=LR_GAMMA)

    ### Training
    for epoch in range(EPOCH):
        print("|Epoch: {:>4} |".format(epoch + 1))

        for step, (x_batch, y_batch) in enumerate(data_loader):
            print("Process: {}/{}".format(step,
                                          int(AVAILABLE_SIZE[0] / BATCHSIZE)),
                  end="\r")
            x = Variable(x_batch).cuda()
            y = Variable(y_batch).cuda()
            out, KLD = vae(x)
            recon_loss = loss_func(out.cuda(), y)
            loss = (recon_loss + KLD_LAMBDA * KLD)

            # Zero first, then backprop/step (equivalent to the original's
            # backward/step/zero ordering, but the conventional form).
            optim.zero_grad()
            loss.backward()
            optim.step()
            lr_step.step()

            if step % RECORD_JSON_PERIOD == 0:
                save_record(model_index, epoch, optim, recon_loss, KLD)
            if step % RECORD_PIC_PERIOD == 0:
                # BUG FIX: original referenced the undefined name 'vaee',
                # which raised NameError the first time this branch ran.
                save_pic("output_{}".format(model_index), vae, 3)
            if step % RECORD_MODEL_PERIOD == 0:
                tor.save(
                    vae.state_dict(),
                    os.path.join(MODEL_ROOT,
                                 # NOTE(review): 'ave_model' looks like a typo
                                 # for 'vae_model'; kept for compatibility with
                                 # existing checkpoints.
                                 "ave_model_{}.pkl".format(model_index)))
예제 #6
0
def extract(fs, idx, N):
    """Encode every frame archive in `fs` with a pretrained VAE (on GPU `idx`)
    and save the resulting mu/logvar together with the original actions,
    rewards and done flags under cfg.seq_extract_dir."""
    model = VAE()
    state = torch.load(cfg.vae_save_ckpt,
                       map_location=lambda storage, loc: storage)
    model.load_state_dict(state['model'])
    model = model.cuda(idx)

    for count, fname in enumerate(fs):
        archive = np.load(fname)
        # Frames arrive as NHWC; the model expects NCHW in [0, 1].
        frames = archive['sx'].transpose(0, 3, 1, 2)
        x = torch.from_numpy(frames).float().cuda(idx) / 255.0
        mu, logvar, _, z = model(x)

        out_path = "{}/{}".format(cfg.seq_extract_dir, fname.split('/')[-1])
        np.savez_compressed(out_path,
                            mu=mu.detach().cpu().numpy(),
                            logvar=logvar.detach().cpu().numpy(),
                            dones=archive['dx'],
                            rewards=archive['rx'],
                            actions=archive['ax'])

        if count % 10 == 0:
            print('Process %d: %5d / %5d' % (idx, count, N))
예제 #7
0
        if (epoch % save_epoch == 0) or (epoch == training_epochs - 1):
            torch.save(scan.state_dict(),
                       '{}/scan_epoch_{}.pth'.format(exp, epoch))


# Prepare the data, then load pretrained DAE/VAE weights and a SCAN checkpoint
# before (optionally) training SCAN.
data_manager = DataManager()
data_manager.prepare()

dae = DAE()
vae = VAE()
scan = SCAN()
if use_cuda:
    # GPU path: load checkpoints with torch.load defaults, then move to CUDA.
    dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth'))
    vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth'))
    scan.load_state_dict(torch.load('save/scan/scan_epoch_1499.pth'))
    dae, vae, scan = dae.cuda(), vae.cuda(), scan.cuda()
else:
    # CPU path: remap saved storages onto the CPU while loading.
    dae.load_state_dict(
        torch.load('save/dae/dae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    vae.load_state_dict(
        torch.load('save/vae/vae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    # NOTE(review): this branch loads `opt.load` from `exp`, while the GPU
    # branch loads a fixed 'scan_epoch_1499.pth' — confirm the asymmetry is
    # intended.
    scan.load_state_dict(
        torch.load(exp + '/' + opt.load,
                   map_location=lambda storage, loc: storage))

if opt.train:
    scan_optimizer = optim.Adam(scan.parameters(), lr=1e-4, eps=1e-8)
    train_scan(dae, vae, scan, data_manager, scan_optimizer)
예제 #8
0
def train_vae(args):
    """Train a text VAE on IMDB with KL-weight ramp-up and temperature decay,
    saving the full model after every epoch to ./snapshot/ (the directory is
    assumed to exist)."""
    hidden_size = 300
    embed_size = 50
    kld_start_inc = 2        # NOTE(review): assigned but never read below
    kld_weight = 0.05        # KL weight, ramped toward kld_max after epoch 1
    kld_max = 0.1
    kld_inc = 0.000002       # per-batch KL weight increment
    temperature = 0.9        # decoder sampling temperature, decayed per batch
    temperature_min = 0.5
    temperature_dec = 0.000002
    USE_CUDA = torch.cuda.is_available()
    print_loss_total = 0     # running sum, reported and reset every 200 batches

    print("[!] preparing dataset...")
    TEXT = data.Field(lower=True, fix_length=30)
    LABEL = data.Field(sequential=False)
    train, test = datasets.IMDB.splits(TEXT, LABEL)
    TEXT.build_vocab(train, max_size=250000)
    LABEL.build_vocab(train)
    train_iter, test_iter = data.BucketIterator.splits(
        (train, test), batch_size=args.batch_size, repeat=False)
    # +2 presumably reserves ids for extra special tokens — TODO confirm.
    vocab_size = len(TEXT.vocab) + 2

    print("[!] Instantiating models...")
    encoder = EncoderRNN(vocab_size,
                         hidden_size,
                         embed_size,
                         n_layers=1,
                         use_cuda=USE_CUDA)
    decoder = DecoderRNN(embed_size,
                         hidden_size,
                         vocab_size,
                         n_layers=2,
                         use_cuda=USE_CUDA)
    vae = VAE(encoder, decoder)
    optimizer = optim.Adam(vae.parameters(), lr=args.lr)
    vae.train()
    if USE_CUDA:
        print("[!] Using CUDA...")
        vae.cuda()

    for epoch in range(1, args.epochs + 1):
        for b, batch in enumerate(train_iter):
            x, y = batch.text, batch.label
            if USE_CUDA:
                x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()

            m, l, z, decoded = vae(x, temperature)
            if temperature > temperature_min:
                temperature -= temperature_dec
            # Reconstruction: token-level cross-entropy against the input.
            recon_loss = F.cross_entropy(decoded.view(-1, vocab_size),
                                         x.contiguous().view(-1))
            # Gaussian KL written with l as log-sigma (exp(l)^2 = sigma^2) —
            # TODO confirm against the encoder's output convention.
            kl_loss = -0.5 * (2 * l - torch.pow(m, 2) -
                              torch.pow(torch.exp(l), 2) + 1)
            # The 0.2 floor acts like a "free bits" constraint on the mean KL.
            kl_loss = torch.clamp(kl_loss.mean(), min=0.2).squeeze()
            loss = recon_loss + kl_loss * kld_weight

            # Ramp the KL weight from the second epoch onward.
            if epoch > 1 and kld_weight < kld_max:
                kld_weight += kld_inc

            loss.backward()
            # NOTE(review): clip_grad_norm (no underscore) and loss.data[0]
            # below are pre-0.4 PyTorch idioms; fine for the torch version
            # this file targets.
            ec = nn.utils.clip_grad_norm(vae.parameters(), args.grad_clip)
            optimizer.step()

            sys.stdout.write(
                '\r[%d] [loss] %.4f - recon_loss: %.4f - kl_loss: %.4f - kld-weight: %.4f - temp: %4f'
                % (b, loss.data[0], recon_loss.data[0], kl_loss.data[0],
                   kld_weight, temperature))
            print_loss_total += loss.data[0]
            if b % 200 == 0 and b != 0:
                print_loss_avg = print_loss_total / 200
                print_loss_total = 0
                print("\n[avg loss] - ", print_loss_avg)
                # Greedy (top-1) sample from the first sequence in the batch.
                _, sample = decoded.data.cpu()[:, 0, :].topk(1)
                print("[ORI]: ",
                      " ".join([TEXT.vocab.itos[i] for i in x.data[:, 0]]))
                print("[GEN]: ",
                      " ".join([TEXT.vocab.itos[i] for i in sample.squeeze()]))
        torch.save(vae, './snapshot/vae_{}.pt'.format(epoch))
예제 #9
0
def main(args):
    """Combine the pitch sequence of one piece with the rhythm of another via
    a pretrained VAE decoder and write the result as a MIDI file.

    Args:
        args: namespace with .config, a YAML file holding 'combine', 'model'
            and 'preprocessor' sections.
    """
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
        conf = config['combine']
        model_params = config['model']
        preprocess_params = config['preprocessor']
    # Output directory. (The original also computed a timestamped subdirectory
    # and immediately discarded it; that dead code has been removed.)
    path = conf['save_path']

    model = VAE(model_params['roll_dim'], model_params['hidden_dim'],
                model_params['infor_dim'], model_params['time_step'], 12)

    model.load_state_dict(torch.load(conf['model_path']))
    if torch.cuda.is_available():
        print('Using: ',
              torch.cuda.get_device_name(torch.cuda.current_device()))
        model.cuda()
    else:
        print('CPU mode')
    model.eval()

    pitch_path = conf['p_path'] + ".txt"
    rhythm_path = conf['r_path'] + ".txt"
    #chord_path = conf['chord_path'] + ".txt"
    name1 = pitch_path.split("/")[-3]
    name2 = rhythm_path.split("/")[-3]
    name = name1 + "+" + name2 + ".mid"
    # NOTE(review): name2 is repurposed as the .txt output name here, so the
    # "Importing" message below prints the combined file name as the rhythm
    # source; kept to preserve the original output.
    name2 = name1 + "+" + name2 + ".txt"

    pitch = np.loadtxt(pitch_path)
    print(pitch)
    rhythm = np.loadtxt(rhythm_path)
    print(rhythm)

    print("Importing " + name1 + " pitch and " + name2 + " rhythm")

    #line_graph(pitch,rhythm)
    #bar_graph(pitch,rhythm)

    pitch = torch.from_numpy(pitch).float()
    rhythm = torch.from_numpy(rhythm).float()
    recon = model.decoder(pitch, rhythm)

    recon = torch.squeeze(recon, 0)
    recon = mf._sampling(recon)
    recon = np.array(recon.cpu().detach().numpy())
    # Trim the roll to the actual number of sounding steps.
    length = torch.sum(rhythm).int()
    recon = recon[:length]
    # Distribution of the generated notes (last two columns are non-pitch).
    note = recon[:, :-2]
    note = np.nonzero(note)[1]
    note = np.bincount(note, minlength=34).astype(float)
    # Restore the cropped pianoroll range before MIDI export.
    recon = mf.modify_pianoroll_dimentions(recon,
                                           preprocess_params['low_crop'],
                                           preprocess_params['high_crop'],
                                           "add")

    #bar_graph(pitch,rhythm)
    mf.numpy_to_midi(recon, 120, path, name,
                     preprocess_params['smallest_note'])

    #pitch_rhythm(recon,path,name2) # write pitch information

    print("combine succeed")
예제 #10
0
    data_dir, train=True, download=True, transform=transforms.ToTensor()),
                                           batch_size=batch_size,
                                           shuffle=True)

# MNIST test-split loader.
# NOTE(review): shuffle=True on the test set is unusual — confirm intended.
test_loader = torch.utils.data.DataLoader(datasets.MNIST(
    data_dir, train=False, download=True, transform=transforms.ToTensor()),
                                          batch_size=batch_size,
                                          shuffle=True)

# Seed the CPU (and GPU, when available) RNGs for reproducibility.
torch.manual_seed(seed)
if use_gpu:
    torch.cuda.manual_seed(seed)

model = VAE()
if use_gpu:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Per-epoch loss histories for the training and test splits.
loss_list = []
test_loss_list = []
for epoch in range(num_epochs + 1):
    # training
    train_loss = train(epoch, model, optimizer, train_loader)
    loss_list.append(train_loss)

    # test
    test_loss = test(epoch, model, test_loader)
    test_loss_list.append(test_loss)

    print('epoch [{}/{}], loss: {:.4f}, test_loss: {:.4f}'.format(
예제 #11
0
class TrainingModel(object):
    def __init__(self, args, config):
        """Seed all RNGs, build the data generator, ranker and VAE models and
        their optimizers, and optionally resume everything from a checkpoint.

        Args:
            args: parsed CLI namespace (uses .gpu, .m, .save, .resume, .visual).
            config: hyper-parameter dict; its keys are also copied onto self
                via __dict__.update for attribute-style access (e.g. self.seed).
        """

        # Expose every config key as an instance attribute.
        self.__dict__.update(config)
        self.config = config
        random.seed(self.seed)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        if use_cuda:
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
            torch.cuda.set_device(args.gpu)

        #torch.backends.cudnn.benchmark = False
        #torch.backends.cudnn.deterministic = True

        self.message = args.m
        self.data_generator = DataGenerator(self.config)
        self.vocab_size = self.data_generator.vocab_size
        self.ent_size = self.data_generator.ent_size

        self.model_name = 'IERM'

        # Checkpoint path: message-tagged when -m is given, else args.save.
        if args.m != "":
            self.saveModeladdr = './trainModel/checkpoint_%s_%s.pkl' % (
                self.model_name, args.m)
        else:
            self.saveModeladdr = './trainModel/' + args.save

        # The VAE shares the ranker's word/entity embedding tables.
        self.model = Ranker(self.vocab_size, self.ent_size, self.config)
        self.VAE_model = VAE(self.vocab_size, self.ent_size,
                             self.model.word_emb, self.model.ent_emb,
                             self.config)

        if use_cuda:
            self.model.cuda()
            self.VAE_model.cuda()

        # Use the pre-training LR while pretrain steps remain, else the VAE LR.
        vae_lr = self.config[
            'pretrain_lr'] if config['pretrain_step'] > 0 else config['vae_lr']
        # NOTE(review): betas=(0.99, 0.99) is unusual for Adam — confirm.
        self.vae_optimizer = getOptimizer(config['vae_optim'],
                                          self.VAE_model.parameters(),
                                          lr=vae_lr,
                                          betas=(0.99, 0.99))
        self.ranker_optimizer = getOptimizer(
            config['ranker_optim'],
            self.model.parameters(),
            lr=config['ranker_lr'],
            weight_decay=config['weight_decay'])

        # Parameter counts (diagnostic only).
        vae_model_size = sum(p.numel() for p in self.VAE_model.parameters())
        ranker_size = sum(p.numel() for p in self.model.parameters())
        #print 'Model size: ', vae_model_size, ranker_size
        #exit(-1)
        if args.resume and os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #print checkpoint.keys()
            self.model.load_state_dict(checkpoint['rank_state_dict'])
            self.VAE_model.load_state_dict(checkpoint['vae_state_dict'])
            self.vae_optimizer.load_state_dict(checkpoint['vae_optimizer'])
            self.ranker_optimizer.load_state_dict(checkpoint['rank_optimizer'])
        else:
            print("Creating a new model")

        self.timings = defaultdict(list)  #record the loss iterations
        self.evaluator = rank_eval()
        self.epoch = 0
        self.step = 0

        # Fixed KL weight (kl_anneal_function exists but is not applied here).
        self.kl_weight = 1

        if args.visual:
            self.config['visual'] = True
            self.writer = SummaryWriter('runs/' + args.m)
        else:
            self.config['visual'] = False
        self.reconstr_loss = nn.MSELoss()

    def add_values(self, iter, value_dict):
        """Log every (tag, scalar) pair in `value_dict` to TensorBoard at step `iter`."""
        for tag, scalar in value_dict.items():
            self.writer.add_scalar(tag, scalar, iter)

    def adjust_learning_rate(self, optimizer, lr, decay_rate=.5):
        """Set every param group's learning rate to lr * decay_rate."""
        scaled_lr = lr * decay_rate
        for group in optimizer.param_groups:
            group['lr'] = scaled_lr

    def kl_anneal_function(self, anneal_function, step, k=0.0025, x0=2500):
        """KL-annealing weight at `step`: a logistic sigmoid (slope k, centre
        x0) or a linear ramp capped at 1; None for any other schedule name."""
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0))))
        if anneal_function == 'linear':
            return min(1, step / x0)

    def vae_loss(self, input_qw, reconstr_w, input_qe, reconstr_e, prior_mean,
                 prior_var, posterior_mean, posterior_var, posterior_log_var):
        """VAE objective: Bernoulli-style reconstruction of the word and/or
        entity bag-of-words plus the closed-form Gaussian KL(posterior||prior).

        Which reconstruction terms are active depends on
        config['reconstruct'] ('word', 'entity', or both otherwise).

        Returns:
            (loss_sum, KL_sum, RL_w_sum, RL_e_sum), each summed over the batch.
        """
        # Reconstruction term
        if self.config['reconstruct'] != 'entity':
            input_qw_bow = to_bow(input_qw, self.vocab_size)
            input_qw_bow = Tensor2Varible(torch.tensor(input_qw_bow).float())
            #reconstr_w = torch.log_softmax(reconstr_w + 1e-10,dim=1)
            #RL_w = -torch.sum(input_qw_bow * reconstr_w , dim=1)
            #RL_w = self.reconstr_loss(reconstr_w,input_qw_bow)
            # Bernoulli cross-entropy; assumes reconstr_w holds
            # log-probabilities (exp(reconstr_w) <= 1) — TODO confirm against
            # the VAE decoder's output.
            RL_w = -torch.sum(
                input_qw_bow * reconstr_w +
                (1 - input_qw_bow) * torch.log(1 - torch.exp(reconstr_w)),
                dim=1)
        else:
            RL_w = Tensor2Varible(torch.tensor([0]).float())
        if self.config['reconstruct'] != 'word':
            input_qe_bow = to_bow(input_qe, self.ent_size)
            input_qe_bow = Tensor2Varible(torch.tensor(input_qe_bow).float())
            #RL_e = -torch.sum(input_qe_bow * reconstr_e, dim=1)
            #RL_e = self.reconstr_loss(reconstr_e,input_qe_bow)
            RL_e = -torch.sum(
                input_qe_bow * reconstr_e +
                (1 - input_qe_bow) * torch.log(1 - torch.exp(reconstr_e)),
                dim=1)
        else:
            RL_e = Tensor2Varible(torch.tensor([0]).float())

        # KL term
        # var division term
        var_division = torch.sum(posterior_var / prior_var, dim=1)
        # diff means term
        diff_means = prior_mean - posterior_mean
        diff_term = torch.sum((diff_means * diff_means) / prior_var, dim=1)
        # logvar det division term
        logvar_det_division = \
            prior_var.log().sum() - posterior_log_var.sum(dim=1)
        # combine terms
        # Closed-form KL between diagonal Gaussians; intent_num is the
        # latent dimensionality.
        KL = 0.5 * (var_division + diff_term - self.model.intent_num +
                    logvar_det_division)

        loss = self.kl_weight * KL + RL_w + RL_e
        #loss = 0.001 * KL + RL_w + RL_e

        return loss.sum(), KL.sum(), RL_w.sum(), RL_e.sum()

    def pretraining(self):
        """Pre-train the VAE alone for `self.pretrain_step` batches, printing
        and logging averaged losses every `self.pretrain_freq` steps, then
        checkpoint and restore the regular VAE learning rate."""
        if self.pretrain_step <= 0:
            return

        train_start_time = time.time()
        data_reader = self.data_generator.pretrain_reader(self.pretrain_bs)
        total_loss = 0.
        total_KL_loss = 0.
        total_RLw_loss = 0.
        total_RLe_loss = 0.
        for step in xrange(self.pretrain_step):  # Python 2 (xrange)
            input_qw, input_qe = next(data_reader)
            #self.kl_weight = self.kl_anneal_function('logistic', step)
            topic_e, vae_loss, kl_loss, rl_w_loss, rl_e_loss = self.train_VAE(
                input_qw, input_qe)
            # NOTE(review): no explicit zero_grad() before backward() — this
            # assumes train_VAE (or the optimizer setup) clears gradients;
            # otherwise they accumulate across steps. Confirm.
            vae_loss.backward()
            torch.nn.utils.clip_grad_value_(
                self.VAE_model.parameters(),
                self.clip_grad)  # clip_grad_norm(, )

            self.vae_optimizer.step()

            vae_loss = vae_loss.data

            #print ('VAE loss: %.3f\tKL: %.3f\tRL_w:%.3f\tRL_e:%.3f' % (vae_loss, kl_loss, rl_w_loss, rl_e_loss))

            # NOTE(review): exit(-1) aborts the process, so the `continue`
            # below is unreachable — and this check runs only after step()
            # has already applied the (NaN) update.
            if torch.isnan(vae_loss):
                print("Got NaN cost .. skipping")
                exit(-1)
                continue

            #if self.config['visual']:
            #    self.add_values(step, {'vae_loss': vae_loss, 'kl_loss': kl_loss, 'rl_w_loss': rl_w_loss,
            #                          'rl_e_loss': rl_e_loss, 'kl_weight': self.kl_weight})

            total_loss += vae_loss
            total_KL_loss += kl_loss
            total_RLw_loss += rl_w_loss
            total_RLe_loss += rl_e_loss

            if step != 0 and step % self.pretrain_freq == 0:
                # Report running averages over the last pretrain_freq steps.
                total_loss /= self.pretrain_freq
                total_KL_loss /= self.pretrain_freq
                total_RLw_loss /= self.pretrain_freq
                total_RLe_loss /= self.pretrain_freq
                print('Step: %d\t Elapsed:%.2f' %
                      (step, time.time() - train_start_time))
                print(
                    'Pretrain VAE loss: %.3f\tKL: %.3f\tRL_w:%.3f\tRL_e:%.3f' %
                    (total_loss, total_KL_loss, total_RLw_loss,
                     total_RLe_loss))
                if self.config['visual']:
                    self.add_values(
                        step, {
                            'vae_loss': total_loss,
                            'kl_loss': total_KL_loss,
                            'rl_w_loss': total_RLw_loss,
                            'rl_e_loss': total_RLe_loss,
                            'kl_weight': self.kl_weight
                        })
                total_loss = 0.
                total_KL_loss = 0.
                total_RLw_loss = 0.
                total_RLe_loss = 0.
                print '=============================================='
                #self.generate_beta_phi_3(show_topic_limit=5)

        self.save_checkpoint(message=self.message + '-pretraining')
        print('Pretraining end')
        #recovering the learning rate
        self.adjust_learning_rate(self.vae_optimizer, self.config['vae_lr'], 1)

    def trainIters(self, ):
        """Main training loop: joint ranker+VAE updates, periodic validation
        with early stopping on ndcg@10 (patience-based), and averaged-loss
        logging every `train_freq` steps."""
        self.step = 0
        train_start_time = time.time()
        patience = self.patience

        best_ndcg10 = 0.0
        last_ndcg10 = 0.0

        data_reader = self.data_generator.pair_reader(self.batch_size)
        total_loss = 0.0
        total_rank_loss = 0.
        total_vae_loss = 0.
        total_KL_loss = 0.
        total_RLw_loss = 0.
        total_RLe_loss = 0.

        for step in xrange(self.steps):  # Python 2 (xrange)
            # Pairwise sample: query + positive and negative documents.
            out = next(data_reader)
            input_qw, input_qe, input_dw_pos, input_de_pos, input_dw_neg, input_de_neg = out
            rank_loss, vae_total_loss, KL_loss, RL_w_loss, RL_e_loss \
                = self.train(input_qw,input_qe,input_dw_pos,input_de_pos,input_dw_neg,input_de_neg)

            cur_total_loss = rank_loss + vae_total_loss
            # Skip bookkeeping for diverged batches.
            if torch.isnan(cur_total_loss):
                print("Got NaN cost .. skipping")
                continue
            self.step += 1
            total_loss += cur_total_loss
            total_rank_loss += rank_loss
            total_vae_loss += vae_total_loss
            total_KL_loss += KL_loss
            total_RLw_loss += RL_w_loss
            total_RLe_loss += RL_e_loss

            if self.eval_freq != -1 and self.step % self.eval_freq == 0:
                # Periodic validation and early-stopping bookkeeping.
                with torch.no_grad():
                    valid_performance = self.test(
                        valid_or_test='valid',
                        source=self.config['click_model'])
                    current_ndcg10 = valid_performance['ndcg@10']

                    if current_ndcg10 > best_ndcg10:
                        print 'Got better result, save to %s' % self.saveModeladdr
                        best_ndcg10 = current_ndcg10
                        patience = self.patience
                        self.save_checkpoint(message=self.message)

                        #self.generate_beta_phi_3(show_topic_limit=5)
                    elif current_ndcg10 <= last_ndcg10 * self.cost_threshold:
                        patience -= 1
                    last_ndcg10 = current_ndcg10

            if self.step % self.train_freq == 0:
                # Report running averages over the last train_freq steps.
                total_loss /= self.train_freq
                total_rank_loss /= self.train_freq
                total_vae_loss /= self.train_freq
                total_KL_loss /= self.train_freq
                total_RLw_loss /= self.train_freq
                total_RLe_loss /= self.train_freq

                self.timings['train'].append(total_loss)
                print('Step: %d\t Elapsed:%.2f' %
                      (step, time.time() - train_start_time))
                print(
                    'Train total loss: %.3f\tRank loss: %.3f\tVAE loss: %.3f' %
                    (total_loss, total_rank_loss, total_vae_loss))
                print('KL loss: %.3f\tRL W: %.3f\tRL E: %.3f' %
                      (total_KL_loss, total_RLw_loss, total_RLe_loss))
                print('Patience left: %d' % patience)

                if self.config['visual']:
                    self.add_values(
                        step, {
                            'Train vae_loss': total_loss,
                            'Train kl_loss': total_KL_loss,
                            'Train rl_w_loss': total_RLw_loss,
                            'Train rl_e_loss': total_RLe_loss,
                            'Train Rank loss': total_rank_loss
                        })

                total_loss = 0
                total_rank_loss = 0.
                total_vae_loss = 0.
                total_KL_loss = 0.
                total_RLw_loss = 0.
                total_RLe_loss = 0.

            if patience < 0:
                print 'patience runs out...'
                break

        print 'Patience___: ', patience
        print("All done, exiting...")

    def test(self, valid_or_test, source):
        """Evaluate the ranker on validation, test, or NTCIR data.

        Args:
            valid_or_test: 'valid', 'ntcir13', 'ntcir14', or anything else for
                the standard test set.
            source: relevance-label type (click model); forced to 'HUMAN' for
                the NTCIR generators.

        Returns:
            dict mapping metric name -> mean value over queries. For non-valid
            runs, also writes per-document predictions under ./results/.
        """

        predicted = []
        results = defaultdict(list)

        if valid_or_test == 'valid':
            is_test = False
            data_addr = self.valid_rank_addr
            data_source = self.data_generator.pointwise_reader_evaluation(
                data_addr, is_test=is_test, label_type=source)
        elif valid_or_test == 'ntcir13' or valid_or_test == 'ntcir14':
            is_test = True
            data_source = self.data_generator.pointwise_ntcir_generator(
                valid_or_test)
            source = 'HUMAN'
        else:
            is_test = True
            data_addr = self.test_rank_addr
            data_source = self.data_generator.pointwise_reader_evaluation(
                data_addr, is_test=is_test, label_type=source)
        # time.clock: Python-2-era timer (removed in Python 3.8).
        start = time.clock()
        count = 0
        for out in data_source:
            (qid, dids, input_qw, input_qe, input_dw, input_de, gt_rels) = out
            # Binarize graded labels per source (Python 2: map returns a list).
            gt_rels = map(lambda t: score2cutoff(source, t), gt_rels)
            rels_predicted = self.predict(input_qw, input_qe, input_dw,
                                          input_de).view(-1).cpu().numpy()

            result = self.evaluator.eval(gt_rels, rels_predicted)
            for did, gt, pred in zip(dids, gt_rels, rels_predicted):
                predicted.append((qid, did, pred, gt))

            for k, v in result.items():
                results[k].append(v)
            count += 1
        elapsed = (time.clock() - start)
        print('Elapsed:%.3f\tAvg:%.3f' % (elapsed, elapsed / count))
        performances = {}

        # Average each metric over all evaluated queries.
        for k, v in results.items():
            performances[k] = np.mean(v)

        print '------Source: %s\tPerformance-------:' % source
        print 'Validating...' if valid_or_test == 'valid' else 'Testing'
        print 'Message: %s' % self.message
        print 'Source: %s' % source
        print performances

        if valid_or_test != 'valid':
            # Persist per-document predictions for offline analysis.
            # NOTE(review): out_file is never closed explicitly.
            path = './results/' + self.message + '_' + valid_or_test + '_' + source
            if not os.path.exists(path):
                os.makedirs(path)
            out_file = open('%s/%s.predicted.txt' % (path, self.model_name),
                            'w')
            for qid, did, pred, gt in predicted:
                print >> out_file, '\t'.join([qid, did, str(pred), str(gt)])

        return performances

    def get_text(self, input, map_fun):
        """Map an id sequence to a space-joined token string, stopping at the
        first 0 (treated as padding/terminator)."""
        tokens = []
        for token_id in input:
            if token_id == 0:
                break
            tokens.append(map_fun(token_id))
        return ' '.join(tokens)

    def generate_beta_phi_3(self, topK=10, show_topic_limit=-1):
        """Decode the top-`topK` words and entities of each latent topic,
        write them to ./topic/<args.m>/topic-words.txt, and return the two
        mappings (topic id -> words, topic id -> entities).

        NOTE(review): reads the module-level `args`, not a self attribute —
        confirm it is defined at call time.
        """
        beta, phi = self.VAE_model.infer_topic_dis(topK)
        topics = defaultdict(list)
        topics_ents = defaultdict(list)
        show_topic_num = self.config[
            'intent_num'] if show_topic_limit == -1 else show_topic_limit

        for i in range(show_topic_num):
            idxs = beta[i]
            eidxs = phi[i]
            component_words = [
                self.data_generator.id2word[idx] for idx in idxs.cpu().numpy()
            ]
            # Entity ids go through the new->old remapping before lookup.
            component_ents = [
                self.data_generator.id2ent[self.data_generator.new2old[idx]]
                for idx in eidxs.cpu().numpy()
            ]
            topics[i] = component_words
            topics_ents[i] = component_ents

        print '--------Topic-Word-------'
        prefix = ('./topic/%s/' % args.m)
        if not os.path.exists(prefix):
            os.makedirs(prefix)
        outfile = open(prefix + 'topic-words.txt', 'w')
        for k in topics:
            print >> outfile, (str(k) + ' : ' + ' '.join(topics[k]))
            print >> outfile, (str(k) + ' : ' + ' '.join(topics_ents[k]))
        return topics, topics_ents

    def run_test_topic(self, out_file_name, topK, topicNum):
        """For each test query, write its text plus its top-3 inferred topics.

        :param out_file_name: path of the report file to write
        :param topK: top words/entities per topic (forwarded to
            generate_beta_phi_3)
        :param topicNum: unused here; kept for interface compatibility
        """
        topics_words, topics_ents = self.generate_beta_phi_3(topK)
        data_addr = self.test_rank_addr
        data_source = self.data_generator.pointwise_reader_evaluation(
            data_addr, is_test=True, label_type=self.config['click_model'])
        # NOTE(review): out_file is never closed; rely on interpreter exit.
        out_file = open(out_file_name, 'w')
        with torch.no_grad():
            self.VAE_model.eval()
            self.model.eval()
            for i, out in enumerate(data_source):
                (qid, dids, input_qw, input_qe, input_dw, input_de,
                 gt_rels) = out
                # Topic mixture for the query (batch of 1).
                theta = self.VAE_model.get_theta(input_qw, input_qe)
                input_qw = input_qw[0]
                input_qe = input_qe[0]

                input_w = self.get_text(
                    input_qw, lambda w: self.data_generator.id2word[w])
                input_e = self.get_text(
                    input_qe, lambda e: self.data_generator.id2ent[
                        self.data_generator.new2old[e]])

                # Top-3 topics by mixture weight.
                theta = theta[0].data.cpu().numpy()
                top_indices = np.argsort(theta)[::-1][:3]

                #print '========================='
                print >> out_file, 'Query: ', input_w
                print >> out_file, 'Entity: ', input_e
                for j, k in enumerate(top_indices):
                    ws = topics_words[k]
                    es = topics_ents[k]
                    print >> out_file, '%d Word Topic %d: %s' % (j, k,
                                                                 ' '.join(ws))
                    print >> out_file, '%d Entity Topic %d: %s' % (
                        j, k, ' '.join(es))

    def generate_topic_word_ent(self, out_file, topK=10):
        """Write each test query's words/entities next to the VAE's top-K
        reconstructed words/entities, one pair of lines per query.

        :param out_file: output file *path* (NOTE: the parameter is rebound
            to the opened file object below)
        :param topK: number of reconstructed ids to show per query
        """
        print 'Visualizing ...'
        data_addr = self.test_rank_addr
        data_source = self.data_generator.pointwise_reader_evaluation(
            data_addr, is_test=True, label_type=self.config['click_model'])
        out_file = open(out_file, 'w')
        with torch.no_grad():
            self.VAE_model.eval()
            self.model.eval()
            for i, out in enumerate(data_source):
                (input_qw, input_qe, input_dw, input_de, gt_rels) = out
                _, word_indices, ent_indices = self.VAE_model.get_topic_words(
                    input_qw, input_qe, topK=topK)
                word_indices = word_indices[0].data.cpu().numpy()
                ent_indices = ent_indices[0].data.cpu().numpy()

                #print 'ent_indices: ', ent_indices
                #print 'word_indices: ', word_indices
                input_qw = input_qw[0]
                input_qe = input_qe[0]

                input_w = self.get_text(
                    input_qw, lambda w: self.data_generator.id2word[w])
                input_e = self.get_text(
                    input_qe, lambda e: self.data_generator.id2ent[
                        self.data_generator.new2old[e]])
                reconstuct_w = self.get_text(
                    word_indices, lambda w: self.data_generator.id2word[w])
                reconstuct_e = self.get_text(
                    ent_indices, lambda e: self.data_generator.id2ent[
                        self.data_generator.new2old[e]])

                print >> out_file, ('%d: Word: %s\tRecons: %s' %
                                    (i + 1, input_w, reconstuct_w))
                print >> out_file, ('%d: Ent: %s\tRecons: %s' %
                                    (i + 1, input_e, reconstuct_e))

    def train_VAE(self, input_qw, input_qe):
        """Forward the query words/entities through the topic VAE and compute
        its loss.

        The backward/step lines below are deliberately commented out: the
        caller (train()) folds vae_total_loss into the joint objective and
        performs a single backward plus both optimizer steps there.

        :param input_qw: query word ids
        :param input_qe: query entity ids
        :return: (topic_embeddings, vae_total_loss, KL, RL_w, RL_e); the
            last three are detached `.data` tensors for logging
        """
        self.VAE_model.train()
        self.VAE_model.zero_grad()
        self.vae_optimizer.zero_grad()

        topic_embeddings, logPw, logPe, prior_mean, prior_variance,\
            poster_mu, poster_sigma, poster_log_sigma = self.VAE_model(input_qw,input_qe)

        vae_total_loss, KL, RL_w, RL_e = self.vae_loss(
            input_qw, logPw, input_qe, logPe, prior_mean, prior_variance,
            poster_mu, poster_sigma, poster_log_sigma)

        #vae_total_loss.backward(retain_graph=True)
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        #torch.nn.utils.clip_grad_value_(self.VAE_model.parameters(), self.clip_grad)  # clip_grad_norm(, )
        #self.vae_optimizer.step()

        return topic_embeddings, vae_total_loss, KL.data, RL_w.data, RL_e.data

    def train(self, input_qw, input_qe, input_dw_pos, input_de_pos,
              input_dw_neg, input_de_neg):
        """One joint optimization step: pairwise hinge ranking loss plus the
        weighted VAE loss and an orthogonality regularizer.

        :param input_qw, input_qe: query word / entity ids
        :param input_dw_pos, input_de_pos: positive document word/entity ids
        :param input_dw_neg, input_de_neg: negative document word/entity ids
        :return: (rank_loss, vae_total_loss, KL, RL_w, RL_e) detached values
        """
        # Turn on training mode which enables dropout.
        self.model.train()
        self.model.zero_grad()
        self.ranker_optimizer.zero_grad()

        # VAE forward only; its gradients flow through total_loss below.
        topic_embeddings, vae_total_loss, KL_loss, RL_w_loss, RL_e_loss = self.train_VAE(
            input_qw, input_qe)

        score_pos, orth_loss_1 = self.model(input_qw, input_qe, input_dw_pos,
                                            input_de_pos, topic_embeddings)
        score_neg, orth_loss_2 = self.model(input_qw, input_qe, input_dw_neg,
                                            input_de_neg, topic_embeddings)

        # Pairwise hinge loss with margin 1.0.
        rank_loss = torch.sum(torch.clamp(1.0 - score_pos + score_neg, min=0))
        vae_weight = self.config['intent_lambda']

        orth_loss = (orth_loss_1 + orth_loss_2) / 2
        total_loss = rank_loss + vae_weight * vae_total_loss + orth_loss
        total_loss.backward()

        ## update parameters
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_value_(self.VAE_model.parameters(),
                                        self.clip_grad)  # clip_grad_norm(, )
        torch.nn.utils.clip_grad_value_(self.model.parameters(),
                                        self.clip_grad)  #clip_grad_norm(, )

        self.ranker_optimizer.step()
        self.vae_optimizer.step()

        return rank_loss.data, vae_total_loss.data, KL_loss, RL_w_loss, RL_e_loss

    def predict(self, input_qw, input_qe, input_dw, input_de):
        """Score documents for a query with gradients disabled.

        NOTE(review): in train_VAE() the call self.VAE_model(qw, qe) returns
        an 8-tuple, but here its result is passed to the ranker unmodified —
        confirm that the VAE's eval-mode forward returns only the topic
        embeddings, or that the ranker unpacks the tuple itself.

        :return: predicted relevance scores from the ranker
        """
        # Turn on evaluation mode which disables dropout.
        with torch.no_grad():
            self.VAE_model.eval()
            self.model.eval()
            topic_embeddings = self.VAE_model(input_qw, input_qe)
            rels_predicted, _ = self.model(input_qw, input_qe, input_dw,
                                           input_de, topic_embeddings)

        return rels_predicted

    def save_checkpoint(self, message):
        """Persist both models and both optimizers to self.saveModeladdr.

        :param message: accepted for interface compatibility; not used here
        """
        target_path = os.path.join(self.saveModeladdr)
        #if not os.path.exists(target_path):
        #    os.makedirs(target_path)
        state = {
            'vae_state_dict': self.VAE_model.state_dict(),
            'rank_state_dict': self.model.state_dict(),
            'vae_optimizer': self.vae_optimizer.state_dict(),
            'rank_optimizer': self.ranker_optimizer.state_dict(),
        }
        torch.save(state, target_path)

    def get_embeddings(self):
        """Dump word, entity and topic embedding matrices to
        ./topic_analysis/w_e_t_embedding.pkl for offline analysis.

        NOTE(review): the pickle file is opened in text mode ('w'); a binary
        pickle protocol would require 'wb' — confirm protocol 0 is intended.
        """
        word_embeddings = self.model.word_emb.weight.detach().cpu().numpy()
        ent_embeddings = self.model.ent_emb.weight.detach().cpu().numpy()
        topic_embeddings = self.model.topic_embedding.detach().cpu().numpy()

        print 'Topic size: ', topic_embeddings.shape[0]
        cPickle.dump((word_embeddings, ent_embeddings, topic_embeddings),
                     open('./topic_analysis/w_e_t_embedding.pkl', 'w'))
        print 'saved'
        return
예제 #12
0
class Runner(object):
    """Train/validate/test driver for a VAE.

    Handles model construction, device placement (including multi-GPU via
    nn.DataParallel), the per-epoch loop, LR scheduling and checkpointing.
    """

    def __init__(self,
                 hparams,
                 train_size: int,
                 class_weight: Optional[Tensor] = None):
        # model, criterion
        self.model = VAE()

        # optimizer and scheduler
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=hparams.learning_rate,
                                          eps=hparams.eps,
                                          weight_decay=hparams.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, **hparams.scheduler)
        # Per-element BCE; reduced manually per sample in run().
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        # self.kld = nn.KLDivLoss(reduction='sum')
        # device
        device_for_summary = self.__init_device(hparams.device,
                                                hparams.out_device)

        # summary
        self.writer = SummaryWriter(logdir=hparams.logdir)
        # TODO: fill in ~~DUMMY~~INPUT~~SIZE~~
        path_summary = Path(self.writer.logdir, 'summary.txt')
        if not path_summary.exists():
            print_to_file(path_summary, summary, (self.model, (40, 11)),
                          dict(device=device_for_summary))

        # save hyperparameters
        path_hparam = Path(self.writer.logdir, 'hparams.txt')
        if not path_hparam.exists():
            print_to_file(path_hparam, hparams.print_params)

    def __init_device(self, device, out_device):
        """Resolve device specs, move model/criterion, return 'cpu' or 'cuda'.

        `device` may be 'cpu', an int, a string whose last character is the
        GPU index, or a sequence of those; more than one device wraps the
        model in nn.DataParallel with `out_device` as the output device.
        """
        if device == 'cpu':
            self.in_device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return 'cpu'

        # Normalize to a list of GPU indices.
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            device = [int(device[-1])]
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d[-1]) for d in device]

        self.in_device = torch.device(f'cuda:{device[0]}')

        if len(device) > 1:
            if type(out_device) == int:
                self.out_device = torch.device(f'cuda:{out_device}')
            else:
                self.out_device = torch.device(out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])

            self.model = nn.DataParallel(self.model,
                                         device_ids=device,
                                         output_device=self.out_device)

        else:
            self.out_device = self.in_device
            self.str_device = str(self.in_device)

        self.model.cuda(self.in_device)
        self.bce.cuda(self.out_device)

        torch.cuda.set_device(self.in_device)

        return 'cuda'

    # Running model for train, test and validation.
    def run(self, dataloader, mode: str, epoch: int):
        """Run one epoch in 'train', 'valid' or 'test' mode.

        :return: (avg_loss, (y, y_est, pred_prob)); the label arrays are
            only populated in test mode.
        """
        self.model.train() if mode == 'train' else self.model.eval()
        if mode == 'test':
            # Reload the checkpoint that step() wrote for this epoch.
            state_dict = torch.load(Path(self.writer.logdir, f'{epoch}.pt'),
                                    map_location='cpu')
            if isinstance(self.model, nn.DataParallel):
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)
            path_test_result = Path(self.writer.logdir, f'test_{epoch}')
            os.makedirs(path_test_result, exist_ok=True)
        else:
            path_test_result = None

        avg_loss = 0.
        y = []
        y_est = []
        pred_prob = []

        pbar = tqdm(dataloader,
                    desc=f'{mode} {epoch:3d}',
                    postfix='-',
                    dynamic_ncols=True)

        for i_batch, batch in enumerate(pbar):
            # data
            x = batch['batch_x']
            x = x.to(self.in_device)  # B, F, T

            # forward
            reconstruct_x, mu, logvar = self.model(x)

            # Per-sample reconstruction error (mean over the 440 features).
            BCE = self.bce(reconstruct_x, x.view(-1, 440)).mean(dim=1)  # (B,)
            if mode != 'test':
                # ELBO-style loss: reconstruction + KL(q(z|x) || N(0, I)).
                loss = torch.mean(
                    BCE - 0.5 *
                    torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
            else:
                loss = 0.

            if mode == 'train':
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                loss = loss.item()

            elif mode == 'valid':
                loss = loss.item()

            else:
                # Test: threshold the reconstruction error as the prediction.
                y += batch['batch_y']
                y_est += (BCE < 0.5).int().tolist()
                pred_prob += BCE.tolist()

            pbar.set_postfix_str('')

            avg_loss += loss

        avg_loss = avg_loss / len(dataloader.dataset)

        y = np.array(y)
        y_est = np.array(y_est)
        pred_prob = np.array(pred_prob, dtype=np.float32)

        return avg_loss, (y, y_est, pred_prob)

    def step(self, valid_loss: float, epoch: int):
        """Advance the LR scheduler, log the LR, checkpoint every 5th epoch.

        :param valid_loss: validation loss fed to ReduceLROnPlateau
        :param epoch: current epoch number
        :return: test epoch or 0
        """
        # self.scheduler.step()
        self.scheduler.step(valid_loss)

        # print learning rate
        for param_group in self.optimizer.param_groups:
            self.writer.add_scalar('learning rate', param_group['lr'], epoch)

        if epoch % 5 == 0:
            # BUGFIX: save the state dict itself (it used to be wrapped in a
            # 1-tuple) and write it to writer.logdir/'<epoch>.pt' — the exact
            # path run(mode='test') loads from (it previously went to the
            # global hparams.logdir/'VAE_<epoch>.pt', breaking the round-trip).
            state_dict = (self.model.module.state_dict() if isinstance(
                self.model, nn.DataParallel) else self.model.state_dict())
            torch.save(state_dict, Path(self.writer.logdir, f'{epoch}.pt'))
        return 0
예제 #13
0
      
    #if epoch % 100 == 99:
      #disentangle_check(session, vae, data_manager)

data_manager = DataManager()
data_manager.prepare()
dae = DAE()
vae = VAE()
# Restore pretrained DAE weights.
# BUGFIX: the CUDA branch previously called load_state_dict with the
# checkpoint *path string* instead of the loaded state dict, which would
# raise at runtime; it must go through torch.load like the CPU branch.
if use_cuda:
  dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth'))
else:
  dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth', map_location=lambda storage, loc: storage))

# Optionally resume the VAE from a previous run under `exp`.
if opt.load != '':
  print('loading {}'.format(opt.load))
  if use_cuda:
    vae.load_state_dict(torch.load(exp+'/'+opt.load))
  else:
    vae.load_state_dict(torch.load(exp+'/'+opt.load, map_location=lambda storage, loc: storage))


if use_cuda: dae, vae = dae.cuda(), vae.cuda()

if opt.train:
  vae_optimizer = optim.Adam(vae.parameters(), lr=1e-4, eps=1e-8)
  train_vae(dae, vae, data_manager, vae_optimizer)


예제 #14
0
          z_mean2[ri][i] = z_m[i]

         
    z_mean2 = Variable(z_mean2)
    if use_cuda: z_mean2 = z_mean2.cuda()
    generated_xs_v = vae.decode(z_mean2)
    generated_xs = dae(generated_xs_v)
    file_name = "disentangle_img/check_z{0}.png".format(target_z_index)
    generated_xs = torch.transpose(generated_xs,2,3)
    generated_xs = torch.transpose(generated_xs,1,3)
    if use_cuda: hsv_image = generated_xs.data.cpu().numpy()
    else: hsv_image = generated_xs.data.numpy()
    print(hsv_image[0].shape)
    save_10_images(hsv_image, file_name)
    

# Load pretrained DAE + VAE and run the latent-traversal visualization.
# These names (vae, dae, use_cuda, data_manager) are globals referenced by
# disentangle_check above.
data_manager = DataManager()
data_manager.prepare()
vae = VAE()
dae = DAE()
if use_cuda:
  dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth'))
  vae = vae.cuda()
  dae = dae.cuda()
  vae.load_state_dict(torch.load('save/vae/vae_epoch_2900.pth'))
else:
  # On CPU, remap CUDA-saved tensors via map_location.
  dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth', map_location=lambda storage, loc: storage))
  vae.load_state_dict(torch.load('save/vae/vae_epoch_2900.pth', map_location=lambda storage, loc: storage))

disentangle_check(dae, vae, data_manager)
예제 #15
0
def main():
    """Train a GRU-based VAE on variable-length .npy pieces from data_dir.

    Parses CLI hyperparameters, builds zero-padded length-sorted batches,
    trains for --epochs epochs and checkpoints parameters to ../params after
    every epoch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden",
                        '-hid',
                        type=int,
                        default=768,
                        help="hidden state dimension")
    parser.add_argument('--epochs',
                        '-e',
                        type=int,
                        default=5,
                        help="number of epochs")
    parser.add_argument('--learning_rate',
                        '-lr',
                        type=float,
                        default=1e-4,
                        help="learning rate")
    parser.add_argument('--grudim',
                        '-gd',
                        type=int,
                        default=1024,
                        help='dimension for gru layer')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=64,
                        help='input batch size for training')
    parser.add_argument('--name',
                        '-n',
                        type=str,
                        default='embedded',
                        help='tensorboard visual name')
    parser.add_argument('--decay',
                        '-d',
                        type=float,
                        default=-1,
                        help='learning rate decay: Gamma')
    parser.add_argument('--beta', type=float, default=0.1, help='beta for kld')
    parser.add_argument('--data',
                        type=int,
                        default=1000,
                        help='how many pieces of music to use')

    args = parser.parse_args()

    hidden_dim = args.hidden
    epochs = args.epochs
    gru_dim = args.grudim
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    decay = args.decay
    beta = args.beta
    data_num = args.data

    folder_name = "hid%d_e%d_gru%d_lr%.4f_batch%d_decay%.4f_beta%.2f_data%d" % (
        hidden_dim, epochs, gru_dim, learning_rate, batch_size, decay, beta,
        data_num)

    writer = SummaryWriter('../logs/{}'.format(folder_name))

    # load data
    file_list = find('*.npy', data_dir)
    f = np.load(data_dir + file_list[0])
    note_dim = f.shape[1]  # feature dimension, assumed shared by all pieces

    model = VAE(note_dim, gru_dim, hidden_dim, batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if decay > 0:
        scheduler = MinExponentialLR(optimizer, gamma=decay, minimum=1e-5)
    step = 0

    if torch.cuda.is_available():
        print('Using: ',
              torch.cuda.get_device_name(torch.cuda.current_device()))
        model.cuda()
    else:
        print('CPU mode')

    # BUGFIX: run exactly `epochs` epochs; range(1, epochs) dropped the last
    # one (consistent with the range(1, epochs + 1) loops elsewhere here).
    for epoch in range(1, epochs + 1):
        print("#" * 5, epoch, "#" * 5)
        # BUGFIX: restore the configured batch size each epoch — the previous
        # epoch may have shrunk it for its final partial batch.
        batch_size = args.batch_size
        batch_data = []
        batch_num = 0
        max_len = 0
        for i in range(len(file_list)):
            # Flush accumulated files into one zero-padded batch at every
            # full batch, at the last file, or at the data cap.
            # NOTE(review): the current file is appended *after* this check,
            # so each boundary element lands in the following batch.
            if i != 0 and i % batch_size == 0 or i == len(
                    file_list) - 1 or i == data_num:
                print("#" * 5, "batch", batch_num)
                if (i == len(file_list) - 1):
                    # Final (possibly partial) batch.
                    # NOTE(review): this is 0 when batch_size divides
                    # len(file_list) evenly — TODO confirm intended.
                    batch_size = len(file_list) % batch_size
                seq_lengths = LongTensor(list(map(len, batch_data)))
                print(seq_lengths.size())
                max_len = torch.max(seq_lengths).item()
                print("max_len:", max_len)
                # Zero-pad every piece to the longest one in the batch.
                batch = np.zeros((max_len, batch_size, note_dim))
                for j in range(len(batch_data)):
                    batch[:batch_data[j].shape[0], j, :] = batch_data[j]
                batch = torch.from_numpy(batch)
                # Sort by length descending (packed-sequence convention).
                seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
                batch = batch[:, perm_idx, :]
                step = train(model, batch, seq_lengths, step, optimizer, beta,
                             writer)
                # reset accumulators for the next batch
                max_len = 0
                batch_data = []
                if decay > 0:
                    scheduler.step()
                batch_num += 1
            data = np.load(data_dir + file_list[i])
            batch_data.append(data)
            if i == data_num:
                break

        print("# saving params")
        param_name = "hid%d_e%d_gru%d_lr%.4f_batch%d_decay%.4f_beta%.2f_data%d_epoch%d" % (
            hidden_dim, epochs, gru_dim, learning_rate, batch_size, decay,
            beta, data_num, epoch)
        save_path = '../params/{}.pt'.format(param_name)
        # BUGFIX: checkpoints go to ../params, but the original created and
        # checked ./params, so saving failed whenever ../params was absent.
        os.makedirs('../params', exist_ok=True)
        if torch.cuda.is_available():
            # Save from CPU for a portable checkpoint, then move back.
            torch.save(model.cpu().state_dict(), save_path)
            model.cuda()
        else:
            torch.save(model.state_dict(), save_path)
        print('# Model saved!')

    writer.close()
예제 #16
0
파일: train.py 프로젝트: yz27/VAE_ISIC2018
# Optionally resume from a checkpoint; tensors are mapped onto CPU first.
if args.resume is not None:
    checkpoint = torch.load(args.resume,
                            map_location=lambda storage, loc: storage)
    print("checkpoint loaded!")
    print("val loss: {}\tepoch: {}\t".format(checkpoint['val_loss'],
                                             checkpoint['epoch']))

# model
model = VAE(args.image_size)
if args.resume is not None:
    model.load_state_dict(checkpoint['state_dict'])

# criterion
criterion = VAELoss(size_average=True, kl_weight=args.kl_weight)
# NOTE(review): `args.cuda is True` is an identity check against the True
# singleton; a truthy non-bool value (e.g. 1) would skip this branch.
if args.cuda is True:
    model = model.cuda()
    criterion = criterion.cuda()

# load data
train_loader, val_loader = load_vae_train_datasets(input_size=args.image_size,
                                                   data=args.data,
                                                   batch_size=args.batch_size)

# load optimizer and scheduler
opt = torch.optim.Adam(params=model.parameters(),
                       lr=args.lr,
                       betas=(0.9, 0.999))
# Restore optimizer state unless the user explicitly asked to reset it.
if args.resume is not None and not args.reset_opt:
    opt.load_state_dict(checkpoint['optimizer'])

scheduler = torch.optim.lr_scheduler.MultiStepLR(opt,
예제 #17
0
                       '{}/recomb_epoch_{}.pth'.format(exp, epoch))


data_manager = DataManager()
data_manager.prepare()

# Component networks of the SCAN pipeline.
dae = DAE()
vae = VAE()
scan = SCAN()
recomb = Recombinator()

# Restore pretrained weights.
# NOTE(review): the CUDA branch never loads the Recombinator checkpoint,
# while the CPU branch restores it from exp + '/' + opt.load — confirm
# whether the asymmetry is intended (e.g. recomb trained from scratch on
# GPU only).
if use_cuda:
    dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth'))
    vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth'))
    scan.load_state_dict(torch.load('save/scan/scan_epoch_1499.pth'))
    dae, vae, scan, recomb = dae.cuda(), vae.cuda(), scan.cuda(), recomb.cuda()
else:
    # On CPU, remap CUDA-saved tensors via map_location.
    dae.load_state_dict(
        torch.load('save/dae/dae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    vae.load_state_dict(
        torch.load('save/vae/vae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    scan.load_state_dict(
        torch.load('save/scan/scan_epoch_1499.pth',
                   map_location=lambda storage, loc: storage))
    recomb.load_state_dict(
        torch.load(exp + '/' + opt.load,
                   map_location=lambda storage, loc: storage))

if opt.train:
예제 #18
0
# Image-modality VAE initialized from a pretrained checkpoint.
netImage = VAE(latent_variable_size=args.latent_dims, batchnorm=True)
netImage.load_state_dict(torch.load(args.pretrained_file))
print("Pre-trained model loaded from %s" % args.pretrained_file)

# Adversarial classifier over latent codes; with --conditional_adv the 10
# class labels are concatenated onto the latent vector (mutually exclusive
# with --conditional).
if args.conditional_adv:
    netClf = FC_Classifier(nz=args.latent_dims + 10)
    assert not args.conditional
else:
    netClf = FC_Classifier(nz=args.latent_dims)

if args.conditional:
    netCondClf = Simple_Classifier(nz=args.latent_dims)

if args.use_gpu:
    netRNA.cuda()  # NOTE(review): netRNA is constructed elsewhere in this file
    netImage.cuda()
    netClf.cuda()
    if args.conditional:
        netCondClf.cuda()

# load data
genomics_dataset = RNA_Dataset(datadir="data/nCD4_gene_exp_matrices/")
image_dataset = NucleiDataset(datadir="data/nuclear_crops_all_experiments",
                              mode="test")

# drop_last keeps every batch at exactly args.batch_size.
image_loader = torch.utils.data.DataLoader(image_dataset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True)
genomics_loader = torch.utils.data.DataLoader(genomics_dataset,
                                              batch_size=args.batch_size,
예제 #19
0
class Trainer(object):
    """MNIST VAE training/evaluation harness (legacy PyTorch < 0.4 API).

    NOTE(review): uses `Variable(..., volatile=True)` and `loss.data[0]`,
    which were removed after PyTorch 0.4; modern versions would need
    `torch.no_grad()` and `.item()`.  Code left unchanged to preserve
    behavior on the targeted version.
    """

    def __init__(self, args):
        self.args = args

        # Seed CPU (and GPU, if enabled) RNGs for reproducibility.
        torch.manual_seed(self.args.seed)
        if self.args.cuda:
            torch.cuda.manual_seed(self.args.seed)

        kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./data',
                           train=True,
                           download=True,
                           transform=transforms.ToTensor()),
            batch_size=self.args.batch_size,
            shuffle=True,
            **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./data',
                           train=False,
                           transform=transforms.ToTensor()),
            batch_size=self.args.batch_size,
            shuffle=True,
            **kwargs)
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.model = VAE()
        if self.args.cuda:
            self.model.cuda()

        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

    def loss_function(self, recon_x, x, mu, logvar):
        """VAE objective: reconstruction BCE plus KL divergence, both
        normalized by batch_size * 784 (28x28 pixels)."""
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))
        # Closed-form KL(q(z|x) || N(0, I)).
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        KLD /= self.args.batch_size * 784
        return BCE + KLD

    def train_one_epoch(self, epoch):
        """One optimization pass over the training set, logging every
        args.log_interval batches."""
        train_loader = self.train_loader
        args = self.args

        self.model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = Variable(data)
            if args.cuda:
                data = data.cuda()
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss = self.loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.data[0]
            self.optimizer.step()
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.data[0] / len(data)))
        print('=====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(train_loader.dataset)))

    def test(self, epoch):
        """Evaluate average loss on the test set and save a grid comparing
        the first batch's inputs to their reconstructions.

        NOTE(review): assumes the results/ directory already exists.
        """
        test_loader = self.test_loader
        args = self.args

        self.model.eval()
        test_loss = 0
        for i, (data, _) in enumerate(test_loader):
            if args.cuda:
                data = data.cuda()
            data = Variable(data, volatile=True)
            recon_batch, mu, logvar = self.model(data)
            test_loss += self.loss_function(recon_batch, data, mu,
                                            logvar).data[0]
            if i == 0:
                # Top row: originals; bottom row: reconstructions.
                n = min(data.size(0), 8)
                comparison = torch.cat([
                    data[:n],
                    recon_batch.view(args.batch_size, 1, 28, 28)[:n]
                ])
                fname = 'results/reconstruction_' + str(epoch) + '.png'
                save_image(comparison.data.cpu(), fname, nrow=n)

        test_loss /= len(test_loader.dataset)
        print('=====> Test set loss: {:.4f}'.format(test_loss))

    def train(self):
        """Full loop: train + test each epoch, then decode 64 random latent
        vectors (dim 20) into a sample image grid."""
        args = self.args
        for epoch in range(1, args.epochs + 1):
            self.train_one_epoch(epoch)
            self.test(epoch)
            sample = Variable(torch.randn(64, 20))
            if args.cuda:
                sample = sample.cuda()
            sample = self.model.decode(sample).cpu()
            save_image(sample.data.view(64, 1, 28, 28),
                       './results/sample_' + str(epoch) + '.png')
예제 #20
0
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data', train=True, download=True, transform=transforms.ToTensor()),
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=False,
                       download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size,
        shuffle=True)

    vae = VAE(n_latents=args.n_latents)
    if args.cuda:
        vae.cuda()

    optimizer = optim.Adam(vae.parameters(), lr=args.lr)

    def train(epoch):
        vae.train()
        loss_meter = AverageMeter()

        for batch_idx, (data, _) in enumerate(train_loader):
            data = Variable(data)
            if args.cuda:
                data = data.cuda()
            optimizer.zero_grad()
            recon_batch, mu, logvar = vae(data)
            loss = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
예제 #21
0
파일: main.py 프로젝트: mufeili/CVAE_PG
def main():
    """Train an actor-critic policy on a gym environment, optionally acting
    on a (C)VAE latent encoding of the observations instead of raw states.

    Experiment modes (args.experiment):
      'a|s'                    -- policy acts on the raw observation.
      'a|z(s)'                 -- policy acts on a VAE encoding of the state.
      'a|z(s, s_next)'         -- VAE is trained to predict the next state.
      'a|z(a_prev, s, s_next)' -- CVAE conditioned on the previous action.

    Side effects: creates a numbered run directory, writes TensorBoard
    summaries via Logger, optionally records videos, saves the policy, and
    pickles a per-experiment training record before returning.

    Fixes vs. the original:
      * guarded the `update_vae` check so the 'a|s' mode no longer raises
        NameError (`update_vae` is only bound for latent experiments);
      * the 100-episode running average now includes the current episode
        and divides by the actual window length instead of a fixed 100.
    """
    time_str = time.strftime("%Y%m%d-%H%M%S")
    print('time_str: ', time_str)

    # Find the first unused numbered run directory for this configuration.
    exp_count = 0

    if args.experiment == 'a|s':
        direc_name_ = '_'.join([args.env, args.experiment])
    else:
        direc_name_ = '_'.join(
            [args.env, args.experiment, 'bp2VAE',
             str(args.bp2VAE)])

    direc_name_exist = True

    while direc_name_exist:
        exp_count += 1
        direc_name = '/'.join([direc_name_, str(exp_count)])
        direc_name_exist = os.path.exists(direc_name)

    try:
        os.makedirs(direc_name)
    except OSError as e:
        # A concurrent run may create the directory between the existence
        # check above and makedirs; only a genuine failure should propagate.
        if e.errno != errno.EEXIST:
            raise

    if args.tensorboard_dir is None:
        logger = Logger('/'.join([direc_name, time_str]))
    else:
        logger = Logger(args.tensorboard_dir)

    env = gym.make(args.env)

    if args.wrapper:
        if args.video_dir is None:
            args.video_dir = '/'.join([direc_name, 'videos'])
        env = gym.wrappers.Monitor(env, args.video_dir, force=True)

    print('observation_space: ', env.observation_space)
    print('action_space: ', env.action_space)
    env.seed(args.seed)
    torch.manual_seed(args.seed)

    # Policy input size: raw observation dim, or the VAE latent dim.
    if args.experiment == 'a|s':
        dim_x = env.observation_space.shape[0]
    elif args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)' or \
            args.experiment == 'a|z(a_prev, s, s_next)':
        dim_x = args.z_dim

    policy = ActorCritic(input_size=dim_x,
                         hidden1_size=3 * dim_x,
                         hidden2_size=6 * dim_x,
                         action_size=env.action_space.n)

    if args.use_cuda:
        Tensor = torch.cuda.FloatTensor
        torch.cuda.manual_seed_all(args.seed)
        policy.cuda()
    else:
        Tensor = torch.FloatTensor

    policy_optimizer = optim.Adam(policy.parameters(), lr=args.policy_lr)

    # Latent-space experiments additionally need a (C)VAE, its optimizer,
    # and a replay buffer of transitions to train it on.
    if args.experiment != 'a|s':
        from util import ReplayBuffer, vae_loss_function

        dim_s = env.observation_space.shape[0]

        if args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)':
            from model import VAE
            vae = VAE(input_size=dim_s,
                      hidden1_size=3 * args.z_dim,
                      hidden2_size=args.z_dim)

        elif args.experiment == 'a|z(a_prev, s, s_next)':
            from model import CVAE
            vae = CVAE(input_size=dim_s,
                       class_size=1,
                       hidden1_size=3 * args.z_dim,
                       hidden2_size=args.z_dim)

        if args.use_cuda:
            vae.cuda()
        vae_optimizer = optim.Adam(vae.parameters(), lr=args.vae_lr)

        if args.experiment == 'a|z(s)':
            from util import Transition_S2S as Transition
        elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
            from util import Transition_S2SNext as Transition

        buffer = ReplayBuffer(args.buffer_capacity, Transition)

        update_vae = True

    # Record type matching the set of losses this experiment produces.
    if args.experiment == 'a|s':
        from util import Record_S
    elif args.experiment == 'a|z(s)':
        from util import Record_S2S
    elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
        from util import Record_S2SNext

    def train_actor_critic(n):
        """One REINFORCE-with-baseline update from the episode just played.

        Uses policy.rewards / policy.saved_info accumulated during the
        episode, then clears both in place.
        """
        saved_info = policy.saved_info

        # Discounted returns, computed backwards over the episode.
        R = 0
        cum_returns_ = []

        for r in policy.rewards[::-1]:
            R = r + args.gamma * R
            cum_returns_.insert(0, R)

        # Normalize returns for variance reduction (eps guards a 1-step episode).
        cum_returns = Tensor(cum_returns_)
        cum_returns = (cum_returns - cum_returns.mean()) \
                      / (cum_returns.std() + np.finfo(np.float32).eps)
        cum_returns = Variable(cum_returns, requires_grad=False).unsqueeze(1)

        batch_info = SavedInfo(*zip(*saved_info))
        batch_log_prob = torch.cat(batch_info.log_prob)
        batch_value = torch.cat(batch_info.value)

        # Policy gradient with the critic value as baseline; Huber loss on
        # the value head.
        batch_adv = cum_returns - batch_value
        policy_loss = -torch.sum(batch_log_prob * batch_adv)
        value_loss = F.smooth_l1_loss(batch_value,
                                      cum_returns,
                                      size_average=False)

        policy_optimizer.zero_grad()
        total_loss = policy_loss + value_loss
        total_loss.backward()
        policy_optimizer.step()

        if args.use_cuda:
            logger.scalar_summary('value_loss', value_loss.data.cpu()[0], n)
            logger.scalar_summary('policy_loss', policy_loss.data.cpu()[0], n)

            all_value_loss.append(value_loss.data.cpu()[0])
            all_policy_loss.append(policy_loss.data.cpu()[0])
        else:
            logger.scalar_summary('value_loss', value_loss.data[0], n)
            logger.scalar_summary('policy_loss', policy_loss.data[0], n)

            all_value_loss.append(value_loss.data[0])
            all_policy_loss.append(policy_loss.data[0])

        # Reset per-episode buffers in place.
        del policy.rewards[:]
        del policy.saved_info[:]

    if args.experiment != 'a|s':

        def train_vae(n):
            """Run args.vae_update_times minibatch updates of the (C)VAE
            from the replay buffer; n is the current episode number."""

            # Global step count for logging, continuing from prior rounds.
            train_times = (n // args.vae_update_frequency -
                           1) * args.vae_update_times

            for i in range(args.vae_update_times):
                train_times += 1

                sample = buffer.sample(args.batch_size)
                batch = Transition(*zip(*sample))
                state_batch = torch.cat(batch.state)

                if args.experiment == 'a|z(s)':
                    # Plain autoencoding: reconstruct the state itself.
                    recon_batch, mu, log_var = vae.forward(state_batch)

                    mse_loss, kl_loss = vae_loss_function(
                        recon_batch,
                        state_batch,
                        mu,
                        log_var,
                        logger,
                        train_times,
                        kl_discount=args.kl_weight,
                        mode=args.experiment)

                elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                    # Predictive coding: decode the *next* state from z(s).
                    next_state_batch = Variable(torch.cat(batch.next_state),
                                                requires_grad=False)
                    predicted_batch, mu, log_var = vae.forward(state_batch)
                    mse_loss, kl_loss = vae_loss_function(
                        predicted_batch,
                        next_state_batch,
                        mu,
                        log_var,
                        logger,
                        train_times,
                        kl_discount=args.kl_weight,
                        mode=args.experiment)

                vae_loss = mse_loss + kl_loss

                vae_optimizer.zero_grad()
                vae_loss.backward()
                vae_optimizer.step()

                logger.scalar_summary('vae_loss', vae_loss.data[0],
                                      train_times)
                all_vae_loss.append(vae_loss.data[0])
                all_mse_loss.append(mse_loss.data[0])
                all_kl_loss.append(kl_loss.data[0])

    # To store cum_reward, value_loss and policy_loss from each episode
    all_cum_reward = []
    all_last_hundred_average = []
    all_value_loss = []
    all_policy_loss = []

    if args.experiment != 'a|s':
        # Store each vae_loss calculated
        all_vae_loss = []
        all_mse_loss = []
        all_kl_loss = []

    for episode in count(1):
        done = False
        state_ = torch.Tensor([env.reset()])
        cum_reward = 0

        # The CVAE mode needs a previous action, so take one random step
        # first and append that action to the state vector.
        if args.experiment == 'a|z(a_prev, s, s_next)':
            action = random.randint(0, 2)
            state_, reward, done, info = env.step(action)
            cum_reward += reward
            state_ = torch.Tensor([np.append(state_, action)])

        while not done:
            if args.experiment == 'a|s':
                state = Variable(state_, requires_grad=False)
            elif args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)' \
                    or args.experiment == 'a|z(a_prev, s, s_next)':
                state_ = Variable(state_, requires_grad=False)
                mu, log_var = vae.encode(state_)

                # Detach unless policy gradients are meant to flow back
                # into the VAE (and the VAE is still being updated).
                if args.bp2VAE and update_vae:
                    state = vae.reparametrize(mu, log_var)
                else:
                    state = vae.reparametrize(mu, log_var).detach()

            action_ = policy.select_action(state)

            if args.use_cuda:
                action = action_.cpu()[0, 0]
            else:
                action = action_[0, 0]

            next_state_, reward, done, info = env.step(action)
            next_state_ = torch.Tensor([next_state_])
            cum_reward += reward

            if args.render:
                env.render()

            policy.rewards.append(reward)

            # Stash transitions for the VAE; next-state modes cannot use
            # the terminal transition (no valid successor).
            if args.experiment == 'a|z(s)':
                buffer.push(state_)
            elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                if not done:
                    buffer.push(state_, next_state_)

            if args.experiment == 'a|z(a_prev, s, s_next)':
                next_state_ = torch.cat(
                    [next_state_, torch.Tensor([action])], 1)

            state_ = next_state_

        train_actor_critic(episode)

        # BUGFIX: record the episode first so the running average includes
        # it, and divide by the actual window size rather than a fixed 100
        # (the original under-reported the average for the first 100
        # episodes and always lagged one episode behind).
        all_cum_reward.append(cum_reward)
        window = all_cum_reward[-100:]
        last_hundred_average = sum(window) / len(window)

        logger.scalar_summary('cum_reward', cum_reward, episode)
        logger.scalar_summary('last_hundred_average', last_hundred_average,
                              episode)

        all_last_hundred_average.append(last_hundred_average)

        # BUGFIX: check the experiment mode before update_vae -- update_vae
        # is only bound for latent experiments, so the original raised
        # NameError in 'a|s' mode.
        if args.experiment != 'a|s' and update_vae:
            if episode % args.vae_update_frequency == 0:
                assert len(buffer) >= args.batch_size
                train_vae(episode)

            # Freeze the VAE once its loss plateaus over two 500-step windows.
            if len(all_vae_loss) > 1000:
                if abs(
                        sum(all_vae_loss[-500:]) / 500 -
                        sum(all_vae_loss[-1000:-500]) /
                        500) < args.vae_update_threshold:
                    update_vae = False

        if episode % args.log_interval == 0:
            print(
                'Episode {}\tLast cum return: {:5f}\t100-episodes average cum return: {:.2f}'
                .format(episode, cum_reward, last_hundred_average))

        if episode > args.num_episodes:
            print("100-episodes average cum return is now {} and "
                  "the last episode runs to {} time steps!".format(
                      last_hundred_average, cum_reward))
            env.close()
            torch.save(policy, '/'.join([direc_name, 'model']))

            if args.experiment == 'a|s':
                record = Record_S(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average)
            elif args.experiment == 'a|z(s)':
                record = Record_S2S(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average,
                    mse_recon_loss=all_mse_loss,
                    kl_loss=all_kl_loss,
                    vae_loss=all_vae_loss)
            elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                record = Record_S2SNext(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average,
                    mse_pred_loss=all_mse_loss,
                    kl_loss=all_kl_loss,
                    vae_loss=all_vae_loss)

            pickle.dump(record, open('/'.join([direc_name, 'record']), 'wb'))

            break