def main(args, num_epochs=50):
    """Train a face VAE (``VAE`` or ``VAE2``) and checkpoint the best weights.

    Each epoch trains on the 'train' split of ``hw3_data/face`` and evaluates
    on its 'test' split. Whenever the validation loss improves, the model's
    ``state_dict`` is saved under ``args.save_folder``.

    Args:
        args: parsed options. Reads ``batch_size``, ``num_workers``,
            ``latent_dim``, ``model``, ``lambda_KL``, ``learning_rate`` and
            ``save_folder``.
        num_epochs: number of training epochs (default 50, the previous
            hard-coded value).
    """
    train_set = DataLoader(dataset=FaceDataset('hw3_data/face', mode='train'),
                           batch_size=args.batch_size,
                           shuffle=True,
                           num_workers=args.num_workers)
    valid_set = DataLoader(dataset=FaceDataset('hw3_data/face', mode='test'),
                           batch_size=args.batch_size,
                           shuffle=False,
                           num_workers=args.num_workers)

    # Build only the requested architecture instead of constructing a VAE
    # and throwing it away when VAE2 was requested.
    model_cls = VAE2 if args.model == 'VAE2' else VAE
    model = model_cls(latent_dim=args.latent_dim)
    # Fall back to CPU rather than failing on machines without CUDA, and make
    # sure the model actually lives on the chosen device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    criterion = VAE_loss(lambda_KL=args.lambda_KL)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # 'lamda' (sic) is kept as-is so the checkpoint file name does not change.
    save_path = '{}/lamda_{:.7f}-dim_{}-{}.pth'.format(args.save_folder,
                                                     args.lambda_KL,
                                                     args.latent_dim,
                                                     args.model)
    # Start at +inf so the first epoch is always checkpointed. The previous
    # fixed threshold (2) meant nothing was saved if the loss stayed >= 2.
    min_loss = float('inf')
    for epoch in range(num_epochs):

        print('\nepoch: {}'.format(epoch))

        train(train_set, model, optimizer, device, criterion)
        val_loss = validation(valid_set, model, device, criterion)
        if val_loss < min_loss:
            torch.save(model.state_dict(), save_path)
            min_loss = val_loss
            print('Best epoch: {}'.format(epoch))
# Example #2
# 0
def run(args, kwargs):
    """Set up and run a single VAE experiment.

    Creates a time-stamped directory under ``snapshots/``, loads the data,
    builds the model selected by ``args.model_name`` and passes everything to
    ``utils.perform_experiment.experiment_vae``. ``args`` is printed and
    appended to ``vae_experiment_log.txt`` in the working directory. After the
    experiment returns, a separator line is printed and appended there too.

    Args:
        args: experiment options. ``args.model_signature`` is set here, and
            ``args`` is replaced by the object ``load_dataset`` returns.
        kwargs: keyword arguments forwarded to ``load_dataset``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model. It
            subclasses ``Exception``, so existing broad handlers still work.
    """
    args.model_signature = str(datetime.datetime.now())[0:19]

    model_name = '{}_{}_{}(K_{})_wu({})_z1_{}_z2_{}'.format(
        args.dataset_name, args.model_name, args.prior,
        args.number_components, args.warmup, args.z1_size, args.z2_size)

    # DIRECTORY FOR SAVING
    snapshots_path = 'snapshots/'
    # ``save_dir`` instead of ``dir`` so the builtin is not shadowed.
    save_dir = snapshots_path + args.model_signature + '_' + model_name + '/'
    # exist_ok replaces the racy exists()-then-makedirs() pair.
    os.makedirs(save_dir, exist_ok=True)

    # LOAD DATA=========================================================================================================
    print('load data')
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # CREATE MODEL======================================================================================================
    print('create model')
    # Import only the implementation that was asked for.
    if args.model_name == 'vae':
        from models.VAE import VAE
    elif args.model_name == 'hvae_2level':
        from models.HVAE_2level import VAE
    elif args.model_name == 'convhvae_2level':
        from models.convHVAE_2level import VAE
    elif args.model_name == 'pixelhvae_2level':
        from models.PixelHVAE_2level import VAE
    else:
        raise ValueError(
            'Wrong name of the model! ({!r})'.format(args.model_name))

    model = VAE(args)
    if args.cuda:
        model.cuda()

    optimizer = AdamNormGrad(model.parameters(), lr=args.lr)

    # ======================================================================================================================
    print(args)
    with open('vae_experiment_log.txt', 'a') as f:
        print(args, file=f)

    # ======================================================================================================================
    print('perform experiment')
    from utils.perform_experiment import experiment_vae
    experiment_vae(args, train_loader, val_loader, test_loader, model,
                   optimizer, save_dir, model_name=args.model_name)
    # ======================================================================================================================
    print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-')
    with open('vae_experiment_log.txt', 'a') as f:
        print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-\n', file=f)
# Example #3
# 0
def run(args, kwargs):
    """Set up and run a single VAE experiment.

    Creates a time-stamped directory under ``snapshots/``, loads the data,
    builds the model selected by ``args.model_name`` and passes everything to
    ``utils.perform_experiment.experiment_vae``. ``args`` is printed and
    appended to ``vae_experiment_log.txt`` in the working directory. After the
    experiment returns, a separator line is printed and appended there too.

    Args:
        args: experiment options. ``args.model_signature`` is set here, and
            ``args`` is replaced by the object ``load_dataset`` returns.
        kwargs: keyword arguments forwarded to ``load_dataset``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model. It
            subclasses ``Exception``, so existing broad handlers still work.
    """
    args.model_signature = str(datetime.datetime.now())[0:19]

    model_name = '{}_{}_{}(K_{})_wu({})_z1_{}_z2_{}'.format(
        args.dataset_name, args.model_name, args.prior,
        args.number_components, args.warmup, args.z1_size, args.z2_size)

    # DIRECTORY FOR SAVING
    snapshots_path = 'snapshots/'
    # ``save_dir`` instead of ``dir`` so the builtin is not shadowed.
    save_dir = snapshots_path + args.model_signature + '_' + model_name + '/'
    # exist_ok replaces the racy exists()-then-makedirs() pair.
    os.makedirs(save_dir, exist_ok=True)

    # LOAD DATA=========================================================================================================
    print('load data')
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # CREATE MODEL======================================================================================================
    print('create model')
    # Import only the implementation that was asked for.
    if args.model_name == 'vae':
        from models.VAE import VAE
    elif args.model_name == 'hvae_2level':
        from models.HVAE_2level import VAE
    elif args.model_name == 'convhvae_2level':
        from models.convHVAE_2level import VAE
    elif args.model_name == 'pixelhvae_2level':
        from models.PixelHVAE_2level import VAE
    else:
        raise ValueError(
            'Wrong name of the model! ({!r})'.format(args.model_name))

    model = VAE(args)
    if args.cuda:
        model.cuda()

    optimizer = AdamNormGrad(model.parameters(), lr=args.lr)

    # ======================================================================================================================
    print(args)
    with open('vae_experiment_log.txt', 'a') as f:
        print(args, file=f)

    # ======================================================================================================================
    print('perform experiment')
    from utils.perform_experiment import experiment_vae
    experiment_vae(args, train_loader, val_loader, test_loader, model,
                   optimizer, save_dir, model_name=args.model_name)
    # ======================================================================================================================
    print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-')
    with open('vae_experiment_log.txt', 'a') as f:
        print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-\n', file=f)
# Example #4
# 0
if __name__ == '__main__':
    # ---- Data ----------------------------------------------------------------
    trainset = CSL_Isolated_Openpose(skeleton_root=skeleton_root,
                                     list_file=train_file,
                                     length=length,
                                     is_normalize=False)
    devset = CSL_Isolated_Openpose(skeleton_root=skeleton_root,
                                   list_file=val_file,
                                   length=length,
                                   is_normalize=False)
    print("Dataset samples: {}".format(len(trainset) + len(devset)))
    trainloader = DataLoader(trainset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers,
                             pin_memory=True)
    testloader = DataLoader(devset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True)

    # ---- Model ---------------------------------------------------------------
    model = VAE(num_class, dropout=dropout).to(device)
    # NOTE(review): start_epoch is only assigned here when a checkpoint is
    # given. If checkpoint is None it must already exist as a module-level
    # name, otherwise the evaluation loop below raises NameError -- confirm.
    if checkpoint is not None:
        start_epoch, best_prec1 = resume_model(model, checkpoint)
    # Wrap in DataParallel only when more than one GPU is visible.
    if torch.cuda.device_count() > 1:
        print("Using {} GPUs".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # ---- Loss & optimizer ----------------------------------------------------
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # ---- Evaluation: a single pass labelled with start_epoch -----------------
    print("Test Started".center(60, '#'))
    for epoch in range(start_epoch, start_epoch + 1):
        test_vae(model, criterion, testloader, device, epoch, log_interval,
                 output_path, is_csl)

    print("Test Finished".center(60, '#'))


def main():
    """Train the FAVAE anomaly-detection model from command-line options.

    Steps:
    - Parse options and seed the RNGs.
    - Optionally write augmented training patches.
    - Train a ``VAE`` next to a frozen pretrained VGG16 ``teacher``
      (``models.vgg16(pretrained=True)``), stopping early via ``EarlyStop``
      with patience 20.

    A training log is written under ``args.save_dir``; sample images are
    written there every 10 epochs. Relies on the module-level ``use_cuda``
    and ``device``.
    """
    parser = argparse.ArgumentParser(description='FAVAE anomaly detection')
    parser.add_argument('--obj', type=str, default='.')
    parser.add_argument('--data_type', type=str, default='mvtec')
    parser.add_argument('--data_path', type=str, default='')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='maximum training epochs')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--validation_ratio', type=float, default=0.2)
    parser.add_argument('--grayscale',
                        action='store_true',
                        help='color or grayscale input image')
    parser.add_argument('--img_resize', type=int, default=128)
    parser.add_argument('--crop_size', type=int, default=128)
    parser.add_argument('--do_aug',
                        action='store_true',
                        help='whether to do data augmentation before training')
    parser.add_argument('--augment_num', type=int, default=10000)
    parser.add_argument('--p_rotate',
                        type=float,
                        default=0.3,
                        help='probability to do image rotation')
    parser.add_argument('--rotate_angle_vari',
                        type=float,
                        default=15.0,
                        help='rotate image between [-angle, +angle]')
    parser.add_argument('--p_rotate_crop',
                        type=float,
                        default=1.0,
                        help='probability to crop inner rotated image')
    parser.add_argument('--p_horizonal_flip',
                        type=float,
                        default=0.3,
                        help='probability to do horizonal flip')
    parser.add_argument('--p_vertical_flip',
                        type=float,
                        default=0.3,
                        help='probability to do vertical flip')
    parser.add_argument('--kld_weight', type=float, default=1.0)
    parser.add_argument('--lr',
                        type=float,
                        default=0.005,
                        help='learning rate of Adam')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.00001,
                        help='decay of Adam')
    parser.add_argument('--seed', type=int, default=None, help='manual seed')
    args = parser.parse_args()

    args.p_crop = 1 if args.crop_size != args.img_resize else 0
    args.train_data_dir = args.data_path + '/' + args.obj + '/train/good'
    args.aug_dir = './train_patches/' + args.obj + '/train/good'

    args.input_channel = 1 if args.grayscale else 3

    # Draw a seed only when none was given, but seed the RNGs in both cases.
    # Previously the seeding calls sat inside the ``if``, so an explicit
    # --seed was never applied.
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.seed)

    args.prefix = time_file_str()
    args.save_dir = './' + args.data_type + '/' + args.obj + '/vgg_feature' + '/seed_{}/'.format(
        args.seed)

    # data augmentation (only when the augmented-patch directory is missing)
    if not os.path.exists(args.aug_dir) and args.do_aug:
        os.makedirs(args.aug_dir)
        img_list = generate_image_list(args)
        augment_images(img_list, args)

    args.train_data_path = './train_patches'

    os.makedirs(args.save_dir, exist_ok=True)
    log_path = os.path.join(args.save_dir,
                            'model_training_log_{}.txt'.format(args.prefix))
    # ``with`` guarantees the log is closed even if training raises.
    with open(log_path, 'w') as log:
        state = dict(args._get_kwargs())
        print_log(state, log)

        # load model and dataset
        model = VAE(input_channel=args.input_channel, z_dim=100).to(device)
        teacher = models.vgg16(pretrained=True).to(device)
        for param in teacher.parameters():
            param.requires_grad = False

        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

        # The former ``crop_size if img_resize != crop_size else img_resize``
        # always evaluated to crop_size.
        img_size = args.crop_size
        kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
        train_dataset = MVTecDataset(args.train_data_path,
                                     class_name=args.obj,
                                     is_train=True,
                                     resize=img_size)
        img_nums = len(train_dataset)
        valid_num = int(img_nums * args.validation_ratio)
        train_num = img_nums - valid_num
        train_data, val_data = torch.utils.data.random_split(
            train_dataset, [train_num, valid_num])
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(val_data,
                                                 batch_size=32,
                                                 shuffle=False,
                                                 **kwargs)

        test_dataset = MVTecDataset(args.data_path,
                                    class_name=args.obj,
                                    is_train=False,
                                    resize=img_size)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=32,
                                                  shuffle=True,
                                                  **kwargs)

        # Fetch fixed batches for debugging snapshots. The builtin next() is
        # the Python 3 iterator protocol; newer PyTorch DataLoader iterators
        # no longer provide a ``.next()`` method.
        x_normal_fixed, _, _ = next(iter(val_loader))
        x_normal_fixed = x_normal_fixed.to(device)

        x_test_fixed, _, _ = next(iter(test_loader))
        x_test_fixed = x_test_fixed.to(device)

        # start training
        save_name = os.path.join(args.save_dir,
                                 '{}_{}_model.pt'.format(args.obj, args.prefix))
        early_stop = EarlyStop(patience=20, save_name=save_name)
        start_time = time.time()
        epoch_time = AverageMeter()
        for epoch in range(1, args.epochs + 1):
            need_hour, need_mins, need_secs = convert_secs2time(
                epoch_time.avg * (args.epochs - epoch))
            need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
                need_hour, need_mins, need_secs)
            print_log(
                ' {:3d}/{:3d} ----- [{:s}] {:s}'.format(epoch, args.epochs,
                                                        time_string(), need_time),
                log)
            train(args, model, teacher, epoch, train_loader, optimizer, log)
            val_loss = val(args, model, teacher, epoch, val_loader, log)

            if early_stop(val_loss, model, optimizer, log):
                break

            if epoch % 10 == 0:
                save_sample = os.path.join(args.save_dir,
                                           '{}val-images.jpg'.format(epoch))
                save_sample2 = os.path.join(args.save_dir,
                                            '{}test-images.jpg'.format(epoch))
                save_snapshot(x_normal_fixed, x_test_fixed, model, save_sample,
                              save_sample2, log)

            epoch_time.update(time.time() - start_time)
            start_time = time.time()
# Example #6
# 0
def run(args, kwargs):
    """Set up and run a single VAE experiment, logging into its snapshot dir.

    Builds the run name from ``args`` and appends '_FI' / '_MI' when those
    flags are exactly ``True``; otherwise it zeroes ``args.F`` / ``args.M``.
    It then creates a time-stamped directory under ``args.snapshot_dir``,
    loads the data, builds the model selected by ``args.model_name`` and
    passes everything to ``utils.perform_experiment.experiment_vae``.
    ``args`` and a closing separator line are appended to
    ``vae_experiment_log.txt`` inside that directory.

    Args:
        args: experiment options. This function mutates it:
            ``model_signature`` is set, ``F``/``M`` may be zeroed, and for
            celeba ``model_name`` is switched to 'celebavae'. It is also
            replaced by the object ``load_dataset`` returns.
        kwargs: keyword arguments forwarded to ``load_dataset``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model. It
            subclasses ``Exception``, so existing broad handlers still work.
    """
    import importlib  # local: only needed to resolve the model module by name

    args.model_signature = str(datetime.datetime.now())[0:19]

    model_name = '{}_{}_{}(M_{})(F_{})_wu({})_z1_{}_hidden_{}_ksi_{}'.format(
        args.dataset_name, args.model_name, args.prior, args.M, args.F,
        args.warmup, args.z1_size, args.number_hidden, args.ksi)

    if args.FI is True:
        model_name += '_FI'
    else:
        args.F = 0

    if args.MI is True:
        model_name += '_MI'
    else:
        args.M = 0

    # DIRECTORY FOR SAVING
    snapshots_path = args.snapshot_dir
    if snapshots_path:
        os.makedirs(snapshots_path, exist_ok=True)
    # os.path.join inserts the separator when snapshot_dir has no trailing
    # '/'. The trailing '/' keeps the string handed to experiment_vae in the
    # same form as before.
    save_dir = os.path.join(snapshots_path,
                            args.model_signature + '_' + model_name) + '/'
    os.makedirs(save_dir, exist_ok=True)
    # Single definition of the log path, used for both writes below.
    log_path = os.path.join(save_dir, 'vae_experiment_log.txt')

    # LOAD DATA=========================================================================================================
    print('load data')
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # CREATE MODEL======================================================================================================
    print('create model')
    if args.dataset_name == "celeba":
        args.model_name = "celebavae"
    # model name -> module under ``models`` that defines the ``VAE`` class
    model_modules = {
        'vae': 'models.VAE',
        'hvae_2level': 'models.HVAE_2level',
        'convhvae_2level': 'models.convHVAE_2level',
        'convvae': 'models.convVAE',
        'pixelhvae_2level': 'models.PixelHVAE_2level',
        'pixelvae': 'models.pixelVAE',
        'iaf_vae': 'models.VAE_ccLinIAF',
        'rev_vae': 'models.REV_VAE',
        'rev_pixelvae': 'models.REV_pixelVAE',
        'celebavae': 'models.CelebaVAE',
    }
    try:
        module_path = model_modules[args.model_name]
    except KeyError:
        raise ValueError(
            'Wrong name of the model! ({!r})'.format(args.model_name)) from None
    VAE = importlib.import_module(module_path).VAE

    model = VAE(args)
    if args.cuda:
        model.cuda()

    optimizer = AdamNormGrad(model.parameters(), lr=args.lr)

    # ======================================================================================================================
    print(args)
    with open(log_path, 'a') as f:
        print(args, file=f)

    # ======================================================================================================================
    print('perform experiment')
    from utils.perform_experiment import experiment_vae
    experiment_vae(args,
                   train_loader,
                   val_loader,
                   test_loader,
                   model,
                   optimizer,
                   save_dir,
                   model_name=args.model_name)
    # ======================================================================================================================
    print(
        '-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-'
    )
    with open(log_path, 'a') as f:
        print(
            '-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-\n',
            file=f)
# Example #7
# 0
def run(args, kwargs):
    """Build a VAE/WAE model from ``args`` and optionally train and/or test it.

    Checkpoint and result directory names come from ``create_dirNames``.
    Training uses ``train_wae`` for the WAE model names and ``train_vae``
    otherwise. ``test`` runs when ``args.testflag`` is set.

    Args:
        args: experiment options. It is replaced by the object
            ``load_dataset`` returns.
        kwargs: keyword arguments forwarded to ``load_dataset``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model. It
            subclasses ``Exception``, so existing broad handlers still work.
    """
    import importlib  # local: only needed to resolve the model class by name

    # LOAD DATA=========================================================================================================
    print('load data')

    checkpoints_dir, results_dir = create_dirNames(args)

    # loading data
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # CREATE MODEL======================================================================================================
    print('create model')
    # model name -> (module, class name); every class is constructed with ``args``
    model_classes = {
        'vae': ('models.VAE', 'VAE'),
        'conv_vae': ('models.conv_vae', 'VAE'),
        'hvae_2level': ('models.HVAE_2level', 'VAE'),
        'convhvae_2level': ('models.convHVAE_2level', 'VAE'),
        'convhvae_2level-smim': ('models.convHVAE_2level', 'SymMIM'),
        'MLP_wae': ('models.MLP_wae', 'WAE'),
        'conv_wae': ('models.conv_wae', 'WAE'),
    }
    try:
        module_path, class_name = model_classes[args.model_name]
    except KeyError:
        raise ValueError(
            'Wrong name of the model! ({!r})'.format(args.model_name)) from None
    model_cls = getattr(importlib.import_module(module_path), class_name)
    model = model_cls(args)

    if args.cuda:
        model.cuda()

    optimizer = AdamNormGrad(model.parameters(),
                             lr=args.lr,
                             betas=(args.beta1, 0.999))

    # ======================================================================================================================
    print(args)

    # ======================================================================================================================
    print('perform experiment')

    # Names routed to the WAE training loop. 'Pixel_wae' and 'conv_wae_2level'
    # have no entry in ``model_classes`` (they are rejected above); they are
    # kept only for parity with the original check.
    wae_models = ('MLP_wae', 'conv_wae', 'Pixel_wae', 'conv_wae_2level')
    if args.trainflag:
        if args.model_name in wae_models:
            train_wae(args, train_loader, val_loader, model, optimizer,
                      checkpoints_dir, results_dir)
        else:
            train_vae(args, train_loader, val_loader, test_loader, model,
                      optimizer, checkpoints_dir, results_dir)

    if args.testflag:
        test(args, train_loader, test_loader, model, checkpoints_dir,
             results_dir)
class MFEC:
    def __init__(self, env, args, device='cpu'):
        """
        Build an MFEC (model-free episodic control) agent.

        Parameters
        ----------
        env: gym.Env
            Environment the agent acts in. ``env.action_space.n`` defines the
            discrete action set.
        args: argparse.Namespace
            Parsed options from train.py (see train.py for per-option help).
        device: str
            Torch device string, e.g. 'cpu' or 'cuda:0'.
        """
        self.env = env
        self.device = device
        self.environment_type = args.environment_type
        self.actions = range(env.action_space.n)

        # Behaviour switches
        self.frames_to_stack = args.frames_to_stack
        self.Q_train_algo = args.Q_train_algo
        self.use_Q_max = args.use_Q_max
        self.force_knn = args.force_knn
        self.weight_neighbors = args.weight_neighbors
        self.delta = args.delta
        # Agent-owned RNG seeded from args.seed (exploration, random projection).
        self.rs = np.random.RandomState(args.seed)

        # Exploration / learning-rate hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.lr = args.lr
        self.q_lr = args.q_lr

        # State-embedding settings shared by every embedding type
        self.vae_batch_size = args.vae_batch_size  # mini-batch size for VAE training
        self.vae_epochs = args.vae_epochs  # number of VAE training epochs
        self.embedding_type = args.embedding_type
        self.SR_embedding_type = args.SR_embedding_type
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width

        kind = self.embedding_type
        if kind == 'VAE':
            self._setup_vae_embedding(args)
        elif kind == 'random':
            flat_dim = self.in_height * self.in_width * self.frames_to_stack
            self.projection = self.rs.randn(self.embedding_size,
                                            flat_dim).astype(np.float32)
        elif kind == 'SR':
            self._setup_sr_embedding(args)

        # Episodic Q-value store
        self.max_memory = args.max_memory
        self.num_neighbors = args.num_neighbors
        self.qec = QEC(self.actions, self.max_memory, self.num_neighbors,
                       self.use_Q_max, self.force_knn, self.weight_neighbors,
                       self.delta, self.q_lr)

        # Per-episode bookkeeping
        self.memory = []
        self.print_every = args.print_every
        self.episodes = 0

    def _setup_vae_embedding(self, args):
        """Create the VAE encoder, its loss and its optimizer from ``args``."""
        self.vae_train_frames = args.vae_train_frames
        self.vae_loss = VAELoss()
        self.vae_print_every = args.vae_print_every
        self.load_vae_from = args.load_vae_from
        self.vae_weights_file = args.vae_weights_file
        encoder = VAE(self.frames_to_stack, self.embedding_size,
                      self.in_height, self.in_width)
        self.vae = encoder.to(self.device)
        self.optimizer = get_optimizer(args.optimizer, self.vae.parameters(),
                                       self.lr)

    def _setup_sr_embedding(self, args):
        """Read the SR options; for a 'random' SR embedding also build the
        projection and, with TD training, the MLP with its loss and optimizer."""
        self.SR_train_algo = args.SR_train_algo
        self.SR_gamma = args.SR_gamma
        self.SR_epochs = args.SR_epochs
        self.SR_batch_size = args.SR_batch_size
        self.n_hidden = args.n_hidden
        self.SR_train_frames = args.SR_train_frames
        self.SR_filename = args.SR_filename
        if self.SR_embedding_type != 'random':
            return
        # NOTE(review): this draws from the global np.random, not self.rs,
        # so it is not controlled by args.seed -- confirm this is intended.
        self.projection = np.random.randn(
            self.embedding_size,
            self.in_height * self.in_width).astype(np.float32)
        if self.SR_train_algo == 'TD':
            self.mlp = MLP(self.embedding_size, self.n_hidden).to(self.device)
            self.loss_fn = nn.MSELoss(reduction='mean')
            self.optimizer = get_optimizer(args.optimizer,
                                           self.mlp.parameters(), self.lr)

    def choose_action(self, values):
        """
        Epsilon-greedy action selection over per-action value estimates.

        With probability ``self.epsilon`` an action is drawn uniformly from
        ``self.actions``. Otherwise one of the indices holding the maximum of
        ``values`` is drawn uniformly, so ties are broken at random. All
        randomness comes from ``self.rs``. The choice is stored in
        ``self.action`` and returned.
        """
        if self.rs.random_sample() < self.epsilon:
            pool = self.actions  # explore
        else:
            scores = np.asarray(values)
            pool = np.flatnonzero(scores == scores.max())  # exploit: all argmaxes
        self.action = self.rs.choice(pool)
        return self.action

    def TD_update(self, prev_embedding, prev_action, reward, values, time):
        """
        One-step Expected-SARSA update for the previous transition.

        The bootstrap term is the expected value of the current state under
        this agent's epsilon-greedy policy,
        ``(1 - epsilon) * max(values) + epsilon * mean(values)``. The target
        ``reward + gamma * bootstrap`` is written to the QEC for
        ``(prev_embedding, prev_action)`` at step ``time - 1``.
        """
        eps = self.epsilon
        expected_next = (1 - eps) * np.max(values) + eps * np.mean(values)
        target = reward + self.gamma * expected_next
        self.qec.update(prev_embedding, prev_action, target, time - 1)

    def MC_update(self):
        """
        Flush the episode memory into the QEC using Monte-Carlo returns.

        Experiences are popped newest-first. The discounted return
        ``G = reward + gamma * G`` is accumulated backwards through the
        episode, and each ``(state, action)`` is updated with its return.
        ``self.memory`` is empty afterwards.
        """
        ret = 0.0
        while self.memory:
            step = self.memory.pop()
            ret = ret * self.gamma + step["reward"]
            self.qec.update(step["state"], step["action"], ret, step["time"])

    def add_to_memory(self, state_embedding, action, reward, time):
        """Append one (state, action, reward, time) step to the episode memory."""
        record = dict(state=state_embedding,
                      action=action,
                      reward=reward,
                      time=time)
        self.memory.append(record)

    def run_episode(self):
        """
        Train an MFEC agent for a single episode.

        The agent interacts with the environment until ``done``, choosing
        actions epsilon-greedily from the QEC estimates. The QEC is updated
        either online (``Q_train_algo == 'TD'``, including the terminal
        transition) or from the stored episode at the end (``'MC'``).

        Returns
        -------
        tuple
            ``(n_extra_steps, episode_frames, total_reward)`` for the
            'fourrooms' environment, where ``n_extra_steps`` is the number of
            steps taken minus ``env.shortest_path_length`` from the start
            state. Otherwise ``(episode_frames, total_reward)``.
        """
        self.episodes += 1
        RENDER_SPEED = 0.04
        RENDER = False

        episode_frames = 0
        total_reward = 0
        total_steps = 0

        # Update epsilon
        # NOTE(review): decay is applied whenever epsilon is still above
        # final_epsilon, so epsilon can end one decay step below it.
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay

        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        done = False
        # 1-based step counter. It used to be named ``time``, which shadowed
        # the ``time`` module and broke ``time.sleep`` in the render branch.
        step = 0
        while not done:
            step += 1
            if self.embedding_type == 'random':
                state = np.array(state).flatten()
                state_embedding = np.dot(self.projection, state)
            elif self.embedding_type == 'VAE':
                state = torch.tensor(state).permute(2, 0, 1)  # (H,W,C)->(C,H,W)
                state = state.unsqueeze(0).to(self.device)
                with torch.no_grad():
                    mu, logvar = self.vae.encoder(state)
                    state_embedding = torch.cat([mu, logvar], 1)
                    state_embedding = state_embedding.squeeze()
                    state_embedding = state_embedding.cpu().numpy()
            elif self.embedding_type == 'SR':
                if self.SR_train_algo == 'TD':
                    state = np.array(state).flatten()
                    state_embedding = np.dot(self.projection, state)
                    with torch.no_grad():
                        state_embedding = self.mlp(
                            torch.tensor(state_embedding)).cpu().numpy()
                elif self.SR_train_algo == 'DP':
                    # NOTE(review): self.true_SR_dict is not set in __init__;
                    # presumably it is built elsewhere (e.g. warmup) -- confirm.
                    s = self.env.state
                    state_embedding = self.true_SR_dict[s]
            # L2-normalise the embedding before querying the QEC.
            state_embedding = state_embedding / np.linalg.norm(state_embedding)
            if RENDER:
                import time as _time  # local so this debug path is self-contained
                self.env.render()
                _time.sleep(RENDER_SPEED)

            # Get estimated value of each action
            values = [
                self.qec.estimate(state_embedding, action)
                for action in self.actions
            ]

            action = self.choose_action(values)
            state, reward, done, _ = self.env.step(action)
            if self.Q_train_algo == 'MC':
                self.add_to_memory(state_embedding, action, reward, step)
            elif self.Q_train_algo == 'TD':
                if step > 1:
                    # Bootstrap the previous transition from this state's values.
                    self.TD_update(prev_embedding, prev_action, prev_reward,
                                   values, step)
            prev_reward = reward
            prev_embedding = state_embedding
            prev_action = action
            total_reward += reward
            total_steps += 1
            episode_frames += self.env.skip

        if self.Q_train_algo == 'MC':
            self.MC_update()
        elif self.Q_train_algo == 'TD':
            # The final transition has no successor step inside the loop, so
            # TD_update never sees it. Store its reward (no bootstrap) here so
            # rewards received on the terminal step reach the QEC. The loop
            # always runs at least once, so the prev_* names are bound.
            self.qec.update(prev_embedding, prev_action, prev_reward, step)
        if self.episodes % self.print_every == 0:
            print("KNN usage:", np.mean(self.qec.knn_usage))
            self.qec.knn_usage = []
            print("Proportion of replace:", np.mean(self.qec.replace_usage))
            self.qec.replace_usage = []
        if self.environment_type == 'fourrooms':
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, episode_frames, total_reward
        else:
            return episode_frames, total_reward

    def warmup(self):
        """
        Collect 1 million frames from random policy and train VAE
        """
        if self.embedding_type == 'VAE':
            if self.load_vae_from is not None:
                self.vae.load_state_dict(torch.load(self.load_vae_from))
                self.vae = self.vae.to(self.device)
            else:
                # Collect 1 million frames from random policy
                print("Generating dataset to train VAE from random policy")
                vae_data = []
                state = self.env.reset()
                total_frames = 0
                while total_frames < self.vae_train_frames:
                    action = random.randint(0, self.env.action_space.n - 1)
                    state, reward, done, _ = self.env.step(action)
                    vae_data.append(state)
                    total_frames += self.env.skip
                    if done:
                        state = self.env.reset()
                # Dataset, Dataloader for 1 million frames
                vae_data = torch.tensor(
                    vae_data
                )  # (N x H x W x C) - (1mill/skip X 84 X 84 X frames_to_stack)
                vae_data = vae_data.permute(0, 3, 1, 2)  # (N x C x H x W)
                vae_dataset = TensorDataset(vae_data)
                vae_dataloader = DataLoader(vae_dataset,
                                            batch_size=self.vae_batch_size,
                                            shuffle=True)
                # Training loop
                print("Training VAE")
                self.vae.train()
                for epoch in range(self.vae_epochs):
                    train_loss = 0
                    for batch_idx, batch in enumerate(vae_dataloader):
                        batch = batch[0].to(self.device)
                        self.optimizer.zero_grad()
                        recon_batch, mu, logvar = self.vae(batch)
                        loss = self.vae_loss(recon_batch, batch, mu, logvar)
                        train_loss += loss.item()
                        loss.backward()
                        self.optimizer.step()
                        if batch_idx % self.vae_print_every == 0:
                            msg = 'VAE Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                                epoch, batch_idx * len(batch),
                                len(vae_dataloader.dataset),
                                loss.item() / len(batch))
                            print(msg)
                    print('====> Epoch {} Average loss: {:.4f}'.format(
                        epoch, train_loss / len(vae_dataloader.dataset)))
                    if self.vae_weights_file is not None:
                        torch.save(self.vae.state_dict(),
                                   self.vae_weights_file)
            self.vae.eval()
        elif self.embedding_type == 'SR':
            if self.SR_embedding_type == 'random':
                if self.SR_train_algo == 'TD':
                    total_frames = 0
                    transitions = []
                    while total_frames < self.SR_train_frames:
                        observation = self.env.reset()
                        s_t = self.env.state  # will not work on Atari
                        done = False
                        while not done:
                            action = np.random.randint(self.env.action_space.n)
                            observation, reward, done, _ = self.env.step(
                                action)
                            s_tp1 = self.env.state  # will not work on Atari
                            transitions.append((s_t, s_tp1))
                            total_frames += self.env.skip
                            s_t = s_tp1
                    # Dataset, Dataloader
                    dataset = SRDataset(self.env, self.projection, transitions)
                    dataloader = DataLoader(dataset,
                                            batch_size=self.SR_batch_size,
                                            shuffle=True)
                    train_losses = []
                    #Training loop
                    for epoch in range(self.SR_epochs):
                        for batch_idx, batch in enumerate(dataloader):
                            self.optimizer.zero_grad()
                            e_t, e_tp1 = batch
                            e_t = e_t.to(self.device)
                            e_tp1 = e_tp1.to(self.device)
                            mhat_t = self.mlp(e_t)
                            mhat_tp1 = self.mlp(e_tp1)
                            target = e_t + self.gamma * mhat_tp1.detach()
                            loss = self.loss_fn(mhat_t, target)
                            loss.backward()
                            self.optimizer.step()
                            train_losses.append(loss.item())
                        print("Epoch:", epoch, "Average loss",
                              np.mean(train_losses))

                    emb_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    SR_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    labels = []
                    room_size = self.env.room_size
                    for i, (state,
                            obs) in enumerate(self.env.state_dict.items()):
                        emb = np.dot(self.projection, obs.flatten())
                        emb_reps[i, :] = emb
                        with torch.no_grad():
                            emb = torch.tensor(emb).to(self.device)
                            SR = self.mlp(emb).cpu().numpy()
                        SR_reps[i, :] = SR
                        if state[0] < room_size + 1 and state[
                                1] < room_size + 1:
                            label = 0
                        elif state[0] > room_size + 1 and state[
                                1] < room_size + 1:
                            label = 1
                        elif state[0] < room_size + 1 and state[
                                1] > room_size + 1:
                            label = 2
                        elif state[0] > room_size + 1 and state[
                                1] > room_size + 1:
                            label = 3
                        else:
                            label = 4
                        labels.append(label)
                    np.save('%s_SR_reps.npy' % (self.SR_filename), SR_reps)
                    np.save('%s_emb_reps.npy' % (self.SR_filename), emb_reps)
                    np.save('%s_labels.npy' % (self.SR_filename), labels)
                elif self.SR_train_algo == 'MC':
                    pass
                elif self.SR_train_algo == 'DP':
                    # Use this to ensure same order every time
                    idx_to_state = {
                        i: state
                        for i, state in enumerate(self.env.state_dict.keys())
                    }
                    state_to_idx = {v: k for k, v in idx_to_state.items()}
                    T = np.zeros([self.env.n_states, self.env.n_states])
                    for i, s in idx_to_state.items():
                        for a in range(4):
                            self.env.state = s
                            _, _, _, _ = self.env.step(a)
                            s_tp1 = self.env.state
                            T[state_to_idx[s], state_to_idx[s_tp1]] += 0.25
                    true_SR = np.eye(self.env.n_states)
                    done = False
                    t = 0
                    while not done:
                        t += 1
                        new_SR = true_SR + (self.SR_gamma**t) * (np.matmul(
                            true_SR, T))
                        done = np.max(np.abs(true_SR - new_SR)) < 1e-10
                        true_SR = new_SR
                    self.true_SR_dict = {}
                    for s, obs in self.env.state_dict.items():
                        idx = state_to_idx[s]
                        self.true_SR_dict[s] = true_SR[idx, :]
        else:
            pass  # random projection doesn't require warmup
Пример #9
0
def train_vae(args, dtype=torch.float32):
    """Train the goal VAE on (state, future state) pairs drawn from expert trajectories.

    The pairs produced by ``generate_pairs`` are shuffled and split 95% / 5% into
    train and held-out sets. The model is then trained for ``args.iter`` outer
    iterations. After each iteration the mean squared error of the predicted next
    states on the held-out set is logged to TensorBoard, and the weights are saved
    to ``<output_path>/<env_name>_<beta>_vae.pt``.

    Args:
        args: namespace providing ``state_dim``, ``output_path``,
            ``expert_traj_path``, ``size_per_traj``, ``model_lr``, ``env_name``,
            ``beta``, ``weight``, ``iter``, ``epoch``, ``optim_batch_size`` and
            ``lr_decay_rate``.
        dtype: torch default dtype to use; the pair tensors are cast to it.
    """
    torch.set_default_dtype(dtype)
    state_dim = args.state_dim
    output_path = args.output_path

    # generate state pairs
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(args.expert_traj_path, "rb") as traj_file:
        expert_traj_raw = list(pickle.load(traj_file))
    state_pairs = generate_pairs(expert_traj_raw,
                                 state_dim,
                                 args.size_per_traj,
                                 max_step=10,
                                 min_step=5)  # tune the step size if needed.

    # shuffle and split (19/20 train, 1/20 held out)
    idx = np.arange(state_pairs.shape[0])
    np.random.shuffle(idx)
    state_pairs = state_pairs[idx, :]
    split = (state_pairs.shape[0] * 19) // 20
    train_pairs = state_pairs[:split, :]
    test_pairs = state_pairs[split:, :]
    print(train_pairs.shape)
    print(test_pairs.shape)

    goal_model = VAE(state_dim, latent_dim=128)
    optimizer_vae = torch.optim.Adam(goal_model.parameters(), lr=args.model_lr)
    save_path = '{}_softbc_{}_{}'.format(
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
        args.beta)
    writer = SummaryWriter(log_dir=os.path.join(output_path, 'runs/' +
                                                save_path))

    # Build tensors from the TRAIN split only; the held-out rows must not be
    # seen during training, otherwise the validation error is meaningless.
    train_tensor = torch.from_numpy(train_pairs).to(dtype)
    test_tensor = torch.from_numpy(test_pairs).to(dtype)

    if args.weight:
        # Weighted rows: the source is the first ``state_dim`` columns and the
        # target block starts at column ``state_dim + 1`` and spans
        # ``state_dim + 1`` columns, its last column being the per-sample weight
        # (used below as ``t_test[:, -1]``).
        # NOTE(review): column ``state_dim`` is skipped; presumably the source
        # weight produced by generate_pairs — confirm.
        width = state_dim + 1
        s, t = train_tensor[:, :state_dim], train_tensor[:, width:2 * width]
        s_test = test_tensor[:, :state_dim]
        t_test = test_tensor[:, width:2 * width]
    else:
        s = train_tensor[:, :state_dim]
        t = train_tensor[:, state_dim:2 * state_dim]
        s_test = test_tensor[:, :state_dim]
        t_test = test_tensor[:, state_dim:2 * state_dim]

    try:
        for i in range(1, args.iter + 1):
            loss = goal_model.train(s, t, epoch=args.epoch, optimizer=optimizer_vae,
                                    batch_size=args.optim_batch_size, beta=args.beta,
                                    use_weight=args.weight)
            # Validation only: no autograd graph is needed.
            with torch.no_grad():
                next_states = goal_model.get_next_states(s_test)
                if args.weight:
                    # Last target column is the weight; the rest is the state.
                    val_error = (t_test[:, -1].unsqueeze(1) *
                                 (t_test[:, :-1] - next_states)**2).mean()
                else:
                    # Unweighted targets contain only the state columns, so the
                    # whole target is compared (no weight column to drop).
                    val_error = ((t_test - next_states)**2).mean()
            writer.add_scalar('loss/vae', loss, i)
            writer.add_scalar('valid/vae', val_error, i)
            if i % args.lr_decay_rate == 0:
                adjust_lr(optimizer_vae, 2.)
            torch.save(
                goal_model.state_dict(),
                os.path.join(output_path,
                             '{}_{}_vae.pt'.format(args.env_name, str(args.beta))))
    finally:
        # Flush pending TensorBoard events even if training raises.
        writer.close()
Пример #10
0
    # network
    vae3d = VAE(
        input_channels=1, 
        output_channels=2,
        latent_dim=latent_dim,
        use_deconv=use_deconv,
        renorm=renorm,
        flag_r_train=0
    )

    vae3d.to(device)
    print(vae3d)

    # optimizer
    optimizer = optim.Adam(vae3d.parameters(), lr = lr, betas=(0.5, 0.999))
    ms = [0.3, 0.5, 0.7, 0.9]
    ms = [np.floor(m * niter).astype(int) for m in ms]
    scheduler = MultiStepLR(optimizer, milestones = ms, gamma = 0.2)

    # logger
    logger = Logger('logs', rootDir, opt['flag_rsa'], opt['case_validation'], opt['case_test'])

    # dataloader
    # dataLoader_train = COSMOS_data_loader(
    #     split='Train',
    #     patchSize=patchSize,
    #     extraction_step=extraction_step,
    #     voxel_size=voxel_size,
    #     case_validation=opt['case_validation'],
    #     case_test=opt['case_test'],
Пример #11
0
class VAE_TRAINER():
    """Train, evaluate and checkpoint a convolutional VAE on rollout observations.

    ``params`` is a dict of settings. Keys read by this class: ``loss``
    (one of 'ms-ssim', 'mse', 'mix'), ``cuda``, ``seed``, ``img_size``,
    ``path_data``, ``batch_size``, ``latent_size``, ``visualize``, ``env``
    (only when visualizing), ``alpha``, ``log_interval``, ``logdir``,
    ``noreload`` and ``vae_location``.
    """

    def __init__(self, params):

        self.params = params
        self.loss_function = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss
        }[params["loss"]]

        # Choose device
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # NOTE(review): ``cudnn.benchmark = True`` enables cuDNN autotuning of
        # convolution algorithms; it is not a fix for numeric divergence as a
        # previous comment claimed. Kept unchanged — confirm intent.
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")
        # Epoch most recently passed to train(); default epoch for eval()/checkpoint().
        self.current_epoch = None

        # Prepare data transformations
        red_size = params["img_size"]
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        transform_val = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.ToTensor(),
        ])

        # Initialize Data loaders
        op_dataset = RolloutObservationDataset(params["path_data"],
                                               transform_train,
                                               train=True)
        val_dataset = RolloutObservationDataset(params["path_data"],
                                                transform_val,
                                                train=False)

        self.train_loader = torch.utils.data.DataLoader(
            op_dataset,
            batch_size=params["batch_size"],
            shuffle=True,
            num_workers=0)
        self.eval_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=params["batch_size"],
            shuffle=False,
            num_workers=0)

        # Initialize model and hyperparams
        self.model = VAE(nc=3,
                         ngf=64,
                         ndf=64,
                         latent_variable_size=params["latent_size"],
                         cuda=self.cuda).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()
        self.visualize = params["visualize"]
        if self.visualize:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])
        self.alpha = params["alpha"] if params["alpha"] else 1.0

    def train(self, epoch):
        """Run one training epoch and report the per-batch average losses.

        Args:
            epoch: epoch index used in log messages and plots. It is also
                remembered as the default epoch for eval() and checkpoint().
        """
        self.model.train()
        self.current_epoch = epoch
        log_interval = self.params["log_interval"]
        # dataset_train.load_next_buffer()
        total_mse = 0
        total_ssim = 0
        train_loss = 0
        # Train step
        for batch_idx, data in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss, mse, ssim = self.loss_function(recon_batch, data, mu, logvar,
                                                 self.alpha)
            loss.backward()

            train_loss += loss.item()
            total_ssim += ssim
            total_mse += mse
            self.optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data),
                    len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
                print('MSE: {} , SSIM: {:.4f}'.format(mse, ssim))

        # Average over the actual number of batches (the last one may be partial).
        num_batches = max(len(self.train_loader), 1)
        mean_train_loss = train_loss / num_batches
        mean_ssim_loss = total_ssim / num_batches
        mean_mse_loss = total_mse / num_batches
        print('-- Epoch: {} Average loss: {:.4f}'.format(
            epoch, mean_train_loss))
        print('-- Average MSE: {:.5f} Average SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'train', 'VAE Train Loss', epoch,
                              mean_train_loss)
        return

    def eval(self, epoch=None):
        """Evaluate on the validation loader and return the per-batch mean loss.

        Args:
            epoch: epoch index for titles and plots. Defaults to the epoch last
                passed to train() (None if train() was never called).

        Returns:
            Mean of the loss function's first output over the validation batches.
        """
        if epoch is None:
            epoch = getattr(self, "current_epoch", None)
        img_size = self.params["img_size"]
        self.model.eval()
        # dataset_test.load_next_buffer()
        eval_loss = 0
        total_mse = 0
        total_ssim = 0
        vis = True
        with torch.no_grad():
            # Eval step
            for data in self.eval_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar = self.model(data)

                loss, mse, ssim = self.loss_function(recon_batch, data, mu,
                                                     logvar, self.alpha)
                eval_loss += loss.item()
                total_ssim += ssim
                total_mse += mse
                if vis:
                    org_title = "Epoch: " + str(epoch)
                    # -1 lets the batch dimension follow the actual batch size,
                    # which may be smaller than params["batch_size"].
                    comparison1 = torch.cat([
                        data[:4],
                        recon_batch.view(-1, 3, img_size, img_size)[:4]
                    ])
                    if self.visualize:
                        self.img_plotter.plot(comparison1, org_title)
                    vis = False

        num_batches = max(len(self.eval_loader), 1)
        mean_eval_loss = eval_loss / num_batches
        mean_ssim_loss = total_ssim / num_batches
        mean_mse_loss = total_mse / num_batches
        print('-- Eval set loss: {:.4f}'.format(mean_eval_loss))
        print('-- Eval MSE: {:.5f} Eval SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'eval', 'VAE Eval Loss', epoch,
                              mean_eval_loss)
            self.plotter.plot('loss', 'mse train', 'VAE MSE Loss', epoch,
                              mean_mse_loss)
            self.plotter.plot('loss', 'ssim train', 'VAE MSE Loss', epoch,
                              mean_ssim_loss)

        return mean_eval_loss

    def init_vae_model(self):
        """Create ``<logdir>/vae`` and, unless ``noreload`` is set, restore model and optimizer state.

        The checkpoint is read from ``<vae_location>/best.tar``. It is mapped onto
        this trainer's device, so GPU-saved checkpoints also load on CPU.
        """
        self.vae_dir = os.path.join(self.params["logdir"], 'vae')
        check_dir(self.vae_dir, 'samples')
        if not self.params["noreload"]:  # and os.path.exists(reload_file):
            reload_file = os.path.join(self.params["vae_location"], 'best.tar')
            state = torch.load(reload_file, map_location=self.device)
            print("Reloading model at epoch {}"
                  ", with eval error {}".format(state['epoch'],
                                                state['precision']))
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

    def checkpoint(self, cur_best, eval_loss, epoch=None):
        """Save the last checkpoint (and the best one if improved); return the new best loss.

        Args:
            cur_best: best eval loss so far (falsy means "none yet").
            eval_loss: eval loss of the current model.
            epoch: epoch stored in the checkpoint. Defaults to the epoch last
                passed to train().
        """
        if epoch is None:
            epoch = getattr(self, "current_epoch", None)
        # Save the best and last checkpoint
        best_filename = os.path.join(self.vae_dir, 'best.tar')
        filename = os.path.join(self.vae_dir, 'checkpoint.tar')
        is_best = not cur_best or eval_loss < cur_best
        if is_best:
            cur_best = eval_loss

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': self.model.state_dict(),
                'precision': eval_loss,
                'optimizer': self.optimizer.state_dict()
            }, is_best, filename, best_filename)
        return cur_best

    def plot(self, train, eval, epochs):
        """Plot train/eval loss curves and save them to ``<logdir>/vae_training_curve.png``."""
        plt.plot(epochs, train, label="train loss")
        plt.plot(epochs, eval, label="eval loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.params["logdir"] + "/vae_training_curve.png")
        plt.close()
for param_setting in param_grid:

    # Copy config, set new variables
    run_config = deepcopy(config)
    run_config.freebits_param = param_setting['free_bits_param']
    run_config.mu_force_beta_param = param_setting['mu_force_beta_param']
    run_config.param_wdropout_k = param_setting['param_wdropout_k']

    vae = VAE(encoder_hidden_size=run_config.vae_encoder_hidden_size,
              decoder_hidden_size=run_config.vae_decoder_hidden_size,
              latent_size=run_config.vae_latent_size,
              vocab_size=run_config.vocab_size,
              param_wdropout_k=run_config.param_wdropout_k,
              embedding_size=run_config.embedding_size).to(run_config.device)

    optimizer = torch.optim.Adam(params=vae.parameters())

    # Initalize results writer
    path_to_results = f'{run_config.results_path}/vae'
    params2string = '-'.join(
        [f"{i}:{param_setting[i]}" for i in param_setting.keys()])

    results_writer = ResultsWriter(
        label=f'{run_config.run_label}--vae-{params2string}', )

    sentence_decoder = utils.make_sentence_decoder(cd.tokenizer, 1)

    if run_config.will_train_vae:
        print(f"Training params: {params2string}")
        train_vae(
            vae,
Пример #13
0
print('accuracy {}'.format(acc.item()))

# Save the trained classifier. ``acc`` is a tensor (``.item()`` is used above),
# so its Python value goes into the filename; formatting the tensor itself would
# embed its repr (e.g. "tensor(...)") in the path.
os.makedirs('joint_models/', exist_ok=True)
torch.save(
    classifier.state_dict(), 'joint_models/joint_classifier_' +
    arguments.dataset_name + 'accuracy_{}'.format(acc.item()) + '.t')

#
### generator
model = VAE(arguments)
if arguments.cuda:
    model = model.cuda()

# Reusing a saved generator is disabled by default (the generator is always
# retrained); set to True to load ``model_path`` when it exists.
LOAD_PRETRAINED_GENERATOR = False

if LOAD_PRETRAINED_GENERATOR and os.path.exists(model_path):
    print('loading model...')

    # Load on CPU and copy into the model's parameters, which already live on
    # the configured device (moved above only when arguments.cuda is set).
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
else:
    print('training model...')
    optimizer = AdamNormGrad(model.parameters(), lr=arguments.lr)
    tr.experiment_vae(arguments, train_loader, val_loader, test_loader, model,
                      optimizer, dr, arguments.model_name)

results = ev.evaluate_vae(arguments, model, train_loader, test_loader, 0,
                          results_path, 'test')
with open(results_path + results_name + '.pk', 'wb') as results_file:
    pickle.dump(results, results_file)
Пример #14
0
def run():
    """Run one VAE experiment configured by the module-level ``args`` namespace.

    Builds a descriptive experiment name and creates ``snapshots/<name>/``.
    Loads the datasets, instantiates the VAE variant selected by
    ``args.model_name`` and hands everything to ``experiment_vae``.

    Side effects: mutates ``args`` (``number_combination`` or ``number_of_flows``
    depending on the variant, and ``dynamic_binarization``). Appends ``args`` and
    a closing separator line to ``vae_experiment_log.txt``.

    Raises:
        Exception: if ``args.model_name`` is not a known variant.
    """
    base_name = args.model_name

    # vae_HF: force number_combination = 0 and tag the name with the flow count.
    # vae_ccLinIAF: force number_of_flows = 1 and tag the name with the
    # combination count.
    variant_tag = ''
    if base_name == 'vae_HF':
        args.number_combination = 0
        variant_tag = '(T_' + str(args.number_of_flows) + ')'
    elif base_name == 'vae_ccLinIAF':
        args.number_of_flows = 1
        variant_tag = '(K_' + str(args.number_combination) + ')'

    name_parts = [
        base_name,
        variant_tag,
        '_wu(' + str(args.warmup) + ')',
        '_z1_' + str(args.z1_size),
    ]
    if args.z2_size > 0:
        name_parts.append('_z2_' + str(args.z2_size))
    model_name = ''.join(name_parts)

    print(args)

    with open('vae_experiment_log.txt', 'a') as log_file:
        print(args, file=log_file)

    # DIRECTORY FOR SAVING
    snapshot_dir = 'snapshots/' + model_name + '/'
    if not os.path.exists(snapshot_dir):
        os.makedirs(snapshot_dir)

    # LOAD DATA=========================================================================================================
    print('load data')
    # Only the dynamic MNIST variant re-binarizes its data.
    args.dynamic_binarization = args.dataset_name == 'dynamic_mnist'

    train_loader, val_loader, test_loader = load_dataset(args)

    # CREATE MODEL======================================================================================================
    print('create model')
    # The model class is imported lazily from the module matching the variant.
    if args.model_name == 'vae':
        from models.VAE import VAE
    elif args.model_name == 'vae_HF':
        from models.VAE_HF import VAE
    elif args.model_name == 'vae_ccLinIAF':
        from models.VAE_ccLinIAF import VAE
    else:
        raise Exception('Wrong name of the model!')

    model = VAE(args)
    if args.cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # ======================================================================================================================
    print('perform experiment')
    from utils.perform_experiment import experiment_vae
    experiment_vae(args, train_loader, val_loader, test_loader, model,
                   optimizer, snapshot_dir, model_name=args.model_name)
    # ======================================================================================================================
    print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-')
    with open('vae_experiment_log.txt', 'a') as log_file:
        print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-\n', file=log_file)