Exemplo n.º 1
0
def main(params):
    """ Finetunes the mBart50 model on some languages and
    then evaluates the BLEU score for each direction.

    Training samples a translation direction each step from a
    temperature-smoothed distribution, optionally adds an auxiliary
    cosine-alignment loss between source/target encoder states, then
    reports BLEU per direction on the held-out split."""

    if params.wandb:
        # Only primitive-valued hyperparameters go into the wandb config
        # (the config must be serialisable).
        wandb.init(project='mnmt', entity='nlp-mnmt-project', group='finetuning',
            config={k: v for k, v in params.__dict__.items() if isinstance(v, (float, int, str, list))})

    new_root_path = params.location
    new_name = params.name
    # `logging` here is the project logging module (it provides TrainLogger),
    # not the stdlib logging package.
    logger = logging.TrainLogger(params)
    logger.make_dirs()
    logger.save_params()

    # load model and tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50").to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # scale in terms of max lr
    # presumably WarmupDecay peaks at lr_scale / sqrt(warmup_steps), which
    # this scaling makes equal to params.max_lr — TODO confirm against WarmupDecay.
    lr_scale = params.max_lr * np.sqrt(params.warmup_steps)
    scheduler = WarmupDecay(optimizer, params.warmup_steps, 1, lr_scale=lr_scale)

    # set dropout
    model.config.dropout = params.dropout 
    model.config.attention_dropout = params.dropout

    def pipeline(dataset, langs, batch_size, max_len):
        """Tokenize, pad and batch `dataset` for one language pair.

        Returns (dataloader, num_examples); each batch is a tuple holding
        one padded input_ids tensor per language in `langs`."""

        cols = ['input_ids_' + l for l in langs]

        def tokenize_fn(example):
            """apply tokenization"""
            l_tok = []
            for lang in langs:
                encoded = tokenizer.encode(example[lang])
                # Overwrite the leading token with the mBART language code so
                # every sequence is tagged with its language.
                encoded[0] = tokenizer.lang_code_to_id[LANG_CODES[lang]]
                l_tok.append(encoded)
            return {'input_ids_' + l: tok for l, tok in zip(langs, l_tok)}

        def pad_seqs(examples):
            """Apply padding"""
            # Transpose the batch from per-example tuples to per-language groups.
            ex_langs = list(zip(*[tuple(ex[col] for col in cols) for ex in examples]))
            # NOTE(review): this pad_sequence accepts a max_len kwarg, so it is
            # presumably a project helper, not torch.nn.utils.rnn.pad_sequence.
            ex_langs = tuple(pad_sequence(x, batch_first=True, max_len=max_len) for x in ex_langs)
            return ex_langs

        dataset = filter_languages(dataset, langs)
        dataset = dataset.map(tokenize_fn)
        dataset.set_format(type='torch', columns=cols)
        num_examples = len(dataset)
        print('-'.join(langs) + ' : {} examples.'.format(num_examples))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                batch_size=batch_size,
                                                collate_fn=pad_seqs)
        return dataloader, num_examples

    # load data
    dataset = load_dataset('ted_multi')
    train_dataset = dataset['train']
    test_dataset = dataset['validation' if params.split == 'val' else 'test']

    # preprocess splits for each direction
    num_train_examples = {}
    train_dataloaders, val_dataloaders, test_dataloaders = {}, {}, {}
    for l1, l2 in combinations(params.langs, 2):
        train_dataloaders[l1+'-'+l2], num_train_examples[l1+'-'+l2] = pipeline(
            train_dataset, [l1, l2], params.batch_size, params.max_len)
        test_dataloaders[l1+'-'+l2], _ = pipeline(test_dataset, [l1, l2], params.batch_size, params.max_len)

    # print dataset sizes
    for direction, num in num_train_examples.items():
        print(direction, ': {} examples.'.format(num))

    def freeze_layers(layers, unfreeze=False):
        """Set requires_grad on the given encoder layer indices (False = freeze)."""
        for n in layers:
            for parameter in model.model.encoder.layers[n].parameters():
                parameter.requires_grad = unfreeze

    # define loss function
    if params.label_smoothing is not None:
        loss_object = LabelSmoothingLoss(params.label_smoothing)
        loss_fn = lambda out, tar: loss_object(out.logits, tar)
    else:
        # fall back to the model's built-in loss
        loss_fn = lambda out, tar: out.loss

    # train the model
    # target of +1 for cosine_embedding_loss: pull the two encodings together
    _target = torch.tensor(1.0).to(device)
    def train_step(x, y, aux=False):
        """One optimisation step on batch (x, y); with aux=True also
        backprops the scaled cosine alignment loss between the pooled
        source and target encoder states."""

        # teacher forcing: decoder input is y[:-1], target is y shifted by one
        y_inp, y_tar = y[:,:-1].contiguous(), y[:,1:].contiguous()
        enc_mask, dec_mask = (x != 0), (y_inp != 0)  # token id 0 is padding

        x, y_inp, y_tar, enc_mask, dec_mask = to_devices(
          (x, y_inp, y_tar, enc_mask, dec_mask), device)

        model.train()
        # NOTE(review): with aux on, the "frozen" layers are made trainable for
        # the main loss and re-frozen before the aux graph is built below —
        # confirm this toggling schedule is intended.
        if aux: freeze_layers(params.frozen_layers, unfreeze=True)
        output = model(input_ids=x, decoder_input_ids=y_inp,
                   labels=y_tar, attention_mask=enc_mask,
                   decoder_attention_mask=dec_mask)
        optimizer.zero_grad()
        loss = loss_fn(output, y_tar)
        # keep the graph alive so the aux loss can backprop through the encoder output
        loss.backward(retain_graph=aux)

        if aux: freeze_layers(params.frozen_layers)
        # when aux is off, the alignment loss below is computed grad-free
        # (monitoring only)
        torch.set_grad_enabled(aux)

        x_enc = output.encoder_last_hidden_state
        y_enc = model.model.encoder(y_inp, attention_mask=dec_mask)['last_hidden_state']
        # masked max-pool over time: push padded positions to -999 so they never win
        x_enc = torch.max(x_enc + -999 * (1-enc_mask.type(x_enc.dtype)).unsqueeze(-1), dim=1)[0]
        y_enc = torch.max(y_enc + -999 * (1-dec_mask.type(y_enc.dtype)).unsqueeze(-1), dim=1)[0]
        aux_loss = F.cosine_embedding_loss(x_enc, y_enc, _target)
        scaled_aux_loss = params.aux_strength * aux_loss
        
        torch.set_grad_enabled(True)
        if aux: scaled_aux_loss.backward()

        optimizer.step()
        scheduler.step()

        accuracy = accuracy_fn(output.logits, y_tar)

        return loss.item(), aux_loss.item(), accuracy.item()

    # prepare iterators
    iterators = {direction: iter(loader) for direction, loader in train_dataloaders.items()}

    # compute sampling probabilites (and set zero shot directions to 0)
    num_examples = num_train_examples.copy()
    # params.zero_shot is a flat [l1, l2, l1, l2, ...] list of language pairs
    zero_shots = [(params.zero_shot[i]+'-'+params.zero_shot[i+1]) for i in range(0, len(params.zero_shot), 2)]
    for d in zero_shots:
        num_examples[d] = 0
    directions, num_examples = list(num_examples.keys()), np.array(list(num_examples.values()))
    # temperature-smoothed sampling distribution over directions
    dir_dist = (num_examples ** params.temp) / ((num_examples ** params.temp).sum())

    #train
    losses, aux_losses, accs = [], [], []
    start_ = time.time()
    for i in range(params.train_steps):

        # sample a direction
        direction = directions[int(np.random.choice(len(num_examples), p=dir_dist))]
        try: # check iterator is not exhausted
            x, y = next(iterators[direction])
        except StopIteration:
            # restart this direction's epoch
            iterators[direction] = iter(train_dataloaders[direction])
            x, y = next(iterators[direction])
        x, y = get_direction(x, y, sample=not params.single_direction)
           
        # train on the direction
        loss, aux_loss, acc = train_step(x, y, aux=params.auxiliary)
        losses.append(loss)
        aux_losses.append(aux_loss)
        accs.append(acc)

        if i % params.verbose == 0:
            # running averages over the last `verbose` steps
            print('Batch {} Loss {:.4f} Aux Loss {:.4f} Acc {:.4f} in {:.4f} secs per batch'.format(
                i, np.mean(losses[-params.verbose:]), np.mean(aux_losses[-params.verbose:]),
                np.mean(accs[-params.verbose:]), (time.time() - start_)/(i+1)))
        if params.wandb:
            wandb.log({'train_loss':loss, 'aux_loss':aux_loss, 'train_acc':acc})

    # save results
    if params.save:
        logger.save_model(params.train_steps, model, optimizer, scheduler=scheduler)
    
    # NOTE(review): 'accuarcy' is misspelled, but it is also the CSV column
    # name written to disk — fix in tandem with anything reading train_results.csv.
    train_results = {'loss':[np.mean(losses)], 'aux_loss':[np.mean(aux_losses)], 'accuarcy':[np.mean(accs)]}
    pd.DataFrame(train_results).to_csv(logger.root_path + '/train_results.csv', index=False)

    # evaluate the model
    def evaluate(x, y, y_code, bleu):
        """Generate translations of x (forcing decoder start token y_code)
        and accumulate predictions/references into the given BLEU metric."""
        y_inp, y_tar = y[:,:-1].contiguous(), y[:,1:].contiguous()
        enc_mask = (x != 0)
        x, y_inp, y_tar, enc_mask = to_devices(
          (x, y_inp, y_tar, enc_mask), device)
        
        model.eval()
        y_pred = model.generate(input_ids=x, decoder_start_token_id=y_code,
            attention_mask=enc_mask, max_length=params.max_len+1,
            num_beams=params.num_beams, length_penalty=params.length_penalty,
            early_stopping=True)
        # drop the forced start token before scoring
        bleu(y_pred[:,1:], y_tar)

    test_results = {}
    for direction, loader in test_dataloaders.items():
        # score the reverse direction too, with a second BLEU accumulator
        alt_direction = '-'.join(reversed(direction.split('-')))
        bleu1, bleu2 = BLEU(), BLEU()
        bleu1.set_excluded_indices([0, 2])  # exclude special tokens (ids 0 and 2) from BLEU
        bleu2.set_excluded_indices([0, 2])
        x_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')[0]]]
        y_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')[-1]]]

        start_ = time.time()
        for i, (x, y) in enumerate(loader):
            # optionally cap evaluation at test_batches batches
            if params.test_batches is not None:
                if i > params.test_batches:
                    break

            evaluate(x, y, y_code, bleu1)
            if not params.single_direction:
                evaluate(y, x, x_code, bleu2)
            if i % params.verbose == 0:
                bl1, bl2 = bleu1.get_metric(), bleu2.get_metric()
                print('Batch {} Bleu1 {:.4f} Bleu2 {:.4f} in {:.4f} secs per batch'.format(
                    i, bl1, bl2, (time.time() - start_)/(i+1)))
                if params.wandb:
                    wandb.log({'Bleu1':bl1, 'Bleu2':bl2})

        test_results[direction] = [bleu1.get_metric()]
        test_results[alt_direction] = [bleu2.get_metric()]

    # save test_results
    pd.DataFrame(test_results).to_csv(logger.root_path + '/test_results.csv', index=False)

    if params.wandb:
        wandb.finish()
Exemplo n.º 2
0
def main():
    """Train an adversarial autoencoder (AAE) on CelebA, periodically
    saving sample images and logging generative-quality metrics
    (FID, IS, precision/recall, density/coverage) to wandb."""

    # load real images info or generate real images info
    inception_model_score = generative_model_score.GenerativeModelScore()
    # lazy mode buffers images and defers inception forward passes to lazy_forward()
    inception_model_score.lazy_mode(True)

    import torchvision
    from torch.autograd import Variable
    from torchvision import transforms
    import tqdm
    import os

    # fixed hyperparameters for this run
    batch_size = 64
    epochs = 1000
    img_size = 32
    save_image_interval = 5
    loss_calculation_interval = 10
    latent_dim = 10
    n_iter = 3  # discriminator updates per batch

    wandb.login()
    wandb.init(project="AAE",
               config={
                   "batch_size": batch_size,
                   "epochs": epochs,
                   "img_size": img_size,
                   "save_image_interval": save_image_interval,
                   "loss_calculation_interval": loss_calculation_interval,
                   "latent_dim": latent_dim,
                   "n_iter": n_iter,
               })
    config = wandb.config

    train_loader, validation_loader, test_loader = get_celebA_dataset(
        batch_size, img_size)
    # train_loader = get_cifar1_dataset(batch_size)

    image_shape = [3, img_size, img_size]

    # Real-image statistics are cached on disk, keyed by a hash of the dataset
    # repr, so repeat runs on the same dataset skip the expensive pass below.
    import hashlib
    real_images_info_file_name = hashlib.md5(
        str(train_loader.dataset).encode()).hexdigest() + '.pickle'

    if os.path.exists('./inception_model_info/' + real_images_info_file_name):
        print("Using generated real image info.")
        print(train_loader.dataset)
        inception_model_score.load_real_images_info('./inception_model_info/' +
                                                    real_images_info_file_name)
    else:
        inception_model_score.model_to('cuda')

        #put real image
        for each_batch in train_loader:
            X_train_batch = each_batch[0]
            inception_model_score.put_real(X_train_batch)

        #generate real images info
        inception_model_score.lazy_forward(batch_size=64,
                                           device='cuda',
                                           real_forward=True)
        inception_model_score.calculate_real_image_statistics()
        #save real images info for next experiments
        inception_model_score.save_real_images_info('./inception_model_info/' +
                                                    real_images_info_file_name)
        #offload inception_model
        inception_model_score.model_to('cpu')

    # AAE components: autoencoder (encoder+decoder), latent discriminator,
    # and a generator objective applied to the encoder.
    encoder = Encoder(latent_dim, image_shape).cuda()
    decoder = Decoder(latent_dim, image_shape).cuda()
    discriminator = Discriminator(latent_dim).cuda()
    ae_optimizer = torch.optim.Adam(itertools.chain(encoder.parameters(),
                                                    decoder.parameters()),
                                    lr=1e-4)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    g_optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)

    # metric histories, appended every loss_calculation_interval epochs
    r_losses = []
    d_losses = []
    g_losses = []
    precisions = []
    recalls = []
    fids = []
    inception_scores_real = []
    inception_scores_fake = []

    for i in range(0, epochs):
        batch_count = 0

        for each_batch in tqdm.tqdm(train_loader):
            batch_count += 1
            X_train_batch = Variable(each_batch[0]).cuda()
            # 1) reconstruction step on the autoencoder
            r_loss = update_autoencoder(ae_optimizer, X_train_batch, encoder,
                                        decoder)

            # 2) n_iter discriminator steps per batch
            for iter_ in range(n_iter):
                d_loss = update_discriminator(d_optimizer, X_train_batch,
                                              encoder, discriminator,
                                              latent_dim)

            # 3) generator (encoder) step against the discriminator
            g_loss = update_generator(g_optimizer, X_train_batch, encoder,
                                      discriminator)

            sampled_images = sample_image(encoder, decoder,
                                          X_train_batch).detach().cpu()

            # only collect fakes on metric epochs, to bound memory use
            if i % loss_calculation_interval == 0:
                inception_model_score.put_fake(sampled_images)

        if i % save_image_interval == 0:
            image = save_images(n_row=10,
                                epoch=i,
                                latent_dim=latent_dim,
                                model=decoder)
            wandb.log({'image': wandb.Image(image, caption='%s_epochs' % i)},
                      step=i)

        if i % loss_calculation_interval == 0:
            #offload all GAN model to cpu and onload inception model to gpu
            encoder = encoder.to('cpu')
            decoder = decoder.to('cpu')
            discriminator = discriminator.to('cpu')
            inception_model_score.model_to('cuda')

            #generate fake images info
            inception_model_score.lazy_forward(batch_size=64,
                                               device='cuda',
                                               fake_forward=True)
            inception_model_score.calculate_fake_image_statistics()
            metrics = inception_model_score.calculate_generative_score()

            #onload all GAN model to gpu and offload inception model to cpu
            inception_model_score.model_to('cpu')
            encoder = encoder.to('cuda')
            decoder = decoder.to('cuda')
            discriminator = discriminator.to('cuda')

            precision, recall, fid, inception_score_real, inception_score_fake, density, coverage = \
                metrics['precision'], metrics['recall'], metrics['fid'], metrics['real_is'], metrics['fake_is'], metrics['density'], metrics['coverage']

            wandb.log(
                {
                    "precision": precision,
                    "recall": recall,
                    "fid": fid,
                    "inception_score_real": inception_score_real,
                    "inception_score_fake": inception_score_fake,
                    "density": density,
                    "coverage": coverage
                },
                step=i)

            # NOTE(review): r_loss/d_loss/g_loss recorded here are from the
            # epoch's last batch, not epoch averages.
            r_losses.append(r_loss)
            d_losses.append(d_loss)
            g_losses.append(g_loss)
            precisions.append(precision)
            recalls.append(recall)
            fids.append(fid)
            inception_scores_real.append(inception_score_real)
            inception_scores_fake.append(inception_score_fake)
            save_scores_and_print(i + 1, epochs, r_loss, d_loss, g_loss,
                                  precision, recall, fid, inception_score_real,
                                  inception_score_fake)

        # drop buffered fakes every epoch so they never accumulate across epochs
        inception_model_score.clear_fake()
    save_losses(epochs, loss_calculation_interval, r_losses, d_losses,
                g_losses)
    wandb.finish()
Exemplo n.º 3
0
    def ppo_train(self,
                  env_fn,
                  epochs,
                  gamma,
                  lam,
                  steps_per_epoch,
                  train_pi_iters,
                  pi_lr,
                  train_vf_iters,
                  vf_lr,
                  penalty_lr,
                  cost_lim,
                  clip_ratio,
                  max_ep_len,
                  save_every,
                  wandb_write=True,
                  logger_kwargs=dict()):
        """Train the actor-critic with cost-penalized PPO.

        Args:
            env_fn: zero-arg callable returning a fresh environment.
            epochs: number of outer training epochs.
            gamma, lam: discount factor and GAE lambda for the buffer.
            steps_per_epoch: total env steps per epoch (split across MPI procs).
            train_pi_iters, pi_lr: policy update iterations / learning rate.
            train_vf_iters, vf_lr: value update iterations / learning rate.
            penalty_lr, cost_lim: penalty learning rate and cost limit
                consumed by update_ppo.
            clip_ratio: PPO clipping parameter.
            max_ep_len: episode step cap before a forced bootstrap.
            save_every: checkpoint period in epochs.
            wandb_write: stream metrics to wandb when True.
            logger_kwargs: forwarded to EpochLogger (read-only here).

        Fixes vs. previous revision: `last_vc` is now initialised on the
        plain-terminal path (previously raised NameError on the first
        terminal episode), and wandb.finish() runs once after all epochs
        instead of inside the epoch loop (which closed the run after the
        first epoch). An unused `penalty` local was removed.
        """

        # 4 million env interactions
        if wandb_write:
            wandb.init(project="TrainingExperts",
                       group="ppo_runs",
                       name='ppo_pen_' + self.config_name)

        # Set up logger
        logger = EpochLogger(**logger_kwargs)
        logger.save_config(locals())

        # Make environment
        env = env_fn()

        obs_dim = env.observation_space.shape
        act_dim = env.action_space.shape

        self.ac = self.actor_critic(env.observation_space, env.action_space,
                                    **self.ac_kwargs)

        # Set up Torch saver for logger setup
        logger.setup_pytorch_saver(self.ac)

        if wandb_write:
            wandb.watch(self.ac)

        # Set up experience buffer (rollout length is split across MPI procs)
        self.local_steps_per_epoch = int(steps_per_epoch / num_procs())
        self.buf = CostPOBuffer(obs_dim, act_dim, self.local_steps_per_epoch,
                                gamma, lam)

        # Set up optimizers for policy and value function
        pi_optimizer = Adam(self.ac.pi.parameters(), lr=pi_lr)
        vf_optimizer = Adam(self.ac.v.parameters(), lr=vf_lr)

        mov_avg_ret, mov_avg_cost = 0, 0

        # Prepare for interaction with environment
        start_time = time.time()
        o, r, d, c, ep_ret, ep_cost, ep_len, cum_cost, cum_reward = env.reset(
        ), 0, False, 0, 0, 0, 0, 0, 0
        rew_mov_avg, cost_mov_avg = [], []

        cur_penalty = self.penalty_init_param

        for epoch in range(epochs):
            for t in range(self.local_steps_per_epoch):
                a, v, vc, logp = self.ac.step(
                    torch.as_tensor(o, dtype=torch.float32))

                # env.step => Take action
                next_o, r, d, info = env.step(a)

                # Include penalty on cost
                c = info.get('cost', 0)

                # Track cumulative cost over training
                cum_reward += r
                cum_cost += c

                ep_ret += r
                ep_cost += c
                ep_len += 1

                # Penalized, normalized reward fed to the buffer
                r_total = r - cur_penalty * c
                r_total /= (1 + cur_penalty)

                self.buf.store(o, a, r_total, v, 0, 0, logp, info)

                # save and log
                logger.store(VVals=v)

                # Update obs (critical!)
                o = next_o

                timeout = ep_len == max_ep_len
                terminal = d or timeout
                epoch_ended = t == self.local_steps_per_epoch - 1

                if terminal or epoch_ended:
                    if epoch_ended and not terminal:
                        print(
                            'Warning: trajectory cut off by epoch at %d steps.'
                            % ep_len,
                            flush=True)
                    # if trajectory didn't reach terminal state, bootstrap value target
                    if timeout or epoch_ended:
                        _, v, _, _ = self.ac.step(
                            torch.as_tensor(o, dtype=torch.float32))
                        last_v = v
                        last_vc = 0
                    else:
                        # true terminal state: no bootstrapping for either
                        # the value or the cost-value target
                        last_v = 0
                        last_vc = 0

                    self.buf.finish_path(last_v, last_vc)

                    if terminal:
                        # only save EpRet / EpLen if trajectory finished
                        print("end of episode return: ", ep_ret)
                        logger.store(EpRet=ep_ret,
                                     EpLen=ep_len,
                                     EpCost=ep_cost)

                        # average ep ret and cost
                        avg_ep_ret = ep_ret
                        avg_ep_cost = ep_cost
                        episode_metrics = {
                            'average ep ret': avg_ep_ret,
                            'average ep cost': avg_ep_cost
                        }

                        if wandb_write:
                            wandb.log(episode_metrics)

                    # Reset environment
                    o, r, d, c, ep_ret, ep_len, ep_cost = env.reset(
                    ), 0, False, 0, 0, 0, 0

            # Save model and save last trajectory
            if (epoch % save_every == 0) or (epoch == epochs - 1):
                logger.save_state({'env': env}, None)

            # Perform PPO update!
            cur_penalty, mov_avg_ret, mov_avg_cost, vf_loss_avg, pi_loss_avg = \
                                         update_ppo(self.ac, cur_penalty, clip_ratio,
                                                    logger,
                                                    self.buf,
                                                    train_pi_iters, pi_optimizer,
                                                    train_vf_iters, vf_optimizer,
                                                    cost_lim, penalty_lr,
                                                    rew_mov_avg, cost_mov_avg)

            if wandb_write:
                update_metrics = {
                    '10p mov avg ret': mov_avg_ret,
                    '10p mov avg cost': mov_avg_cost,
                    'value function loss': vf_loss_avg,
                    'policy loss': pi_loss_avg,
                    'current penalty': cur_penalty
                }

                wandb.log(update_metrics)
            #  Cumulative cost calculations
            cumulative_cost = mpi_sum(cum_cost)
            cumulative_reward = mpi_sum(cum_reward)

            cost_rate = cumulative_cost / ((epoch + 1) * steps_per_epoch)
            reward_rate = cumulative_reward / ((epoch + 1) * steps_per_epoch)

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('EpCost', with_min_and_max=True)
            logger.log_tabular('VVals', with_min_and_max=True)
            logger.log_tabular('TotalEnvInteracts',
                               (epoch + 1) * steps_per_epoch)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            logger.log_tabular('DeltaLossPi', average_only=True)
            logger.log_tabular('DeltaLossV', average_only=True)
            logger.log_tabular('Entropy', average_only=True)
            logger.log_tabular('KL', average_only=True)
            logger.log_tabular('ClipFrac', average_only=True)
            logger.log_tabular('StopIter', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()

            if wandb_write:
                log_metrics = {
                    'cost rate': cost_rate,
                    'reward rate': reward_rate
                }
                wandb.log(log_metrics)

        # Close the wandb run once, after all epochs (previously this ran
        # inside the loop and terminated logging after the first epoch).
        if wandb_write:
            wandb.finish()
Exemplo n.º 4
0
def main(args):
    """Train one of several AAE variants (aae/mask_aae, mimic, non-prior,
    learning-prior) selected by args.model_name, periodically logging
    metrics and PCA plots to wandb, then save the trained models."""
    train_loader, _ = data_helper.get_data(args.dataset, args.batch_size,
                                           args.image_size, args.environment)
    if args.wandb:
        wandb_name = "%s[%d]_%s" % (args.dataset, args.image_size,
                                    args.model_name)
        wandb.login()
        wandb.init(project="AAE", config=args, name=wandb_name)
    inception_model_score = load_inception_model(train_loader, args.dataset,
                                                 args.image_size,
                                                 args.environment)
    # Base AAE components; some are replaced below depending on model_name.
    ae_optimizer, d_optimizer, decoder, discriminator, encoder, g_optimizer, mapper = \
        model.get_aae_model_and_optimizer(args)
    if args.model_name == 'mimic':
        # mimic swaps in its own mapper and needs a pretrained autoencoder
        mapper = model.Mimic(args.latent_dim, args.latent_dim,
                             args.mapper_inter_nz,
                             args.mapper_inter_layer).to(args.device)
        decoder, encoder = pretrain_autoencoder(ae_optimizer, args, decoder,
                                                encoder, train_loader)
    if args.model_name == 'non-prior':
        mapper, m_optimizer = model.get_nonprior_model_and_optimizer(args)
    if args.model_name == 'learning-prior':
        mapper, m_optimizer, discriminator_forpl, dpl_optimizer = \
            model.get_learning_prior_model_and_optimizer(args)
        decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)

    # start_time is global so helpers like timeout() can read it
    global start_time
    start_time = time.time()
    if args.pretrain_epoch > 0:
        pretrain_autoencoder(ae_optimizer, args, decoder, encoder,
                             train_loader)

    log_dict, log, log2 = {}, {}, {}
    for i in range(0, args.epochs):
        # per-epoch logs; log/log2 hold the last batch's metric dicts
        log_dict, log, log2 = {}, {}, {}
        # stop early once the wall-clock budget is exhausted
        if args.time_limit and timeout(args.time_limit, start_time): break
        encoded_feature_list = []
        for each_batch in tqdm.tqdm(train_loader,
                                    desc="train[%d/%d]" % (i, args.epochs)):
            each_batch = each_batch[0].to(args.device)
            # dispatch the per-batch update by model variant
            if args.model_name in ['aae', 'mask_aae']:
                log = model.update_aae(ae_optimizer, args, d_optimizer,
                                       decoder, discriminator, each_batch,
                                       encoder, g_optimizer, args.latent_dim)
            elif args.model_name == 'mimic':
                log, encoded_feature = \
                    model.update_autoencoder(ae_optimizer, each_batch, encoder, decoder, return_encoded_feature=True)
                encoded_feature_list.append(encoded_feature)
            elif args.model_name == 'non-prior':
                log, encoded_feature = model.update_autoencoder(
                    ae_optimizer,
                    each_batch,
                    encoder,
                    decoder,
                    return_encoded_feature_gpu=True,
                    flag_retain_graph=False)
                log2 = model.update_posterior_part(args, mapper, discriminator,
                                                   m_optimizer, d_optimizer,
                                                   encoded_feature)
            elif args.model_name == 'learning-prior':
                log = model.update_aae_with_mappedz(args, ae_optimizer,
                                                    d_optimizer, decoder,
                                                    discriminator, mapper,
                                                    each_batch, encoder,
                                                    g_optimizer)
                log2 = model.update_mapper_with_discriminator_forpl(
                    args, dpl_optimizer, decoder_optimizer, m_optimizer,
                    discriminator_forpl, decoder, mapper, each_batch)
            # mimic additionally trains its mapper on the collected features
            if args.model_name == 'mimic':
                g_loss = model.train_mapper(args, encoder, mapper, args.device,
                                            args.lr, args.batch_size,
                                            encoded_feature_list)

        log_dict.update(log)
        log_dict.update(log2)

        # When wandb is on and neither time_check nor time_limit is set,
        # log at every log interval.
        if args.wandb and not args.time_check and not args.time_limit:
            decoder, discriminator, encoder, mapper = log_and_write_pca(
                args, decoder, discriminator, encoder, i,
                inception_model_score, mapper, log_dict)

    # When wandb is on and either time_check or time_limit is set,
    # log once at the very end instead.
    if args.wandb and (args.time_check or args.time_limit):
        decoder, discriminator, encoder, mapper = log_and_write_pca(
            args, decoder, discriminator, encoder, i, inception_model_score,
            mapper, log_dict)

    save_models(args, decoder, encoder, mapper)

    if args.wandb:
        wandb.finish()
Exemplo n.º 5
0
def main_per_process(rank, world_size, args):
    """Per-process entry point for DistributedDataParallel CIFAR-10 training.

    Spawned once per process (one per GPU); rank 0 additionally handles
    wandb logging, console output and the final test pass.
    """
    init_process(rank, world_size)
    start_epoch = 1
    # wandb only on rank 0 so the run is logged exactly once
    if args.wandb and rank == 0:
        run_name = get_run_name(args)
        wandb.init(project='myproject', entity='myaccount')
        wandb.run.name = run_name
        wandb.config.update(args)
    if rank == 0:
        output_cuda_info()

    # load dataset
    train_val_split = 0.2
    # split the global batch size evenly across processes
    batch_size_per_proc = int(args.batch_size / world_size)
    train_set, val_set, test_set = load_cifar10(train_val_split,
                                                args.pretrained)

    # create sampler for ddp (each rank sees a disjoint shard)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_set, num_replicas=world_size, rank=rank, shuffle=False)

    # create data loader for ddp
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size_per_proc,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=batch_size_per_proc,
                                             num_workers=args.num_workers,
                                             pin_memory=True,
                                             sampler=val_sampler)
    # test loader is not sharded: only rank 0 runs the final test below
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    # create ddp model
    model = make_model(args.model,
                       10,
                       pretrained=args.pretrained,
                       fix_param=args.fixparam)
    model = model.to(rank)  # rank doubles as the CUDA device index
    # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    if rank == 0:
        output_summary(ddp_model, train_loader)

    # settings for training
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epoch)

    # synchronize
    dist.barrier()

    # start training
    print(f'[{datetime.now()}]#{rank}: start training')
    for epoch in range(start_epoch, start_epoch + args.epoch):
        if rank == 0:
            print(f'Epoch[{epoch}/{args.epoch}]')
        # set_epoch keeps the per-rank shuffles deterministic and distinct per epoch
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)
        dist.barrier()  # synchronize

        # train and validate
        train_loss, train_acc = train_epoch(ddp_model, train_loader, optimizer,
                                            criterion, rank)
        val_loss, val_acc = validate_epoch(ddp_model, val_loader, criterion,
                                           rank)
        dist.barrier()  # synchronize

        # sharing loss and accuracy among all gpus(processes)
        train_loss_list = [0.] * world_size
        train_acc_list = [0.] * world_size
        val_loss_list = [0.] * world_size
        val_acc_list = [0.] * world_size
        dist.all_gather_object(train_loss_list, train_loss)
        dist.all_gather_object(train_acc_list, train_acc)
        dist.all_gather_object(val_loss_list, val_loss)
        dist.all_gather_object(val_acc_list, val_acc)

        # save data to wandb (rank 0 averages the per-rank values)
        if args.wandb and rank == 0:
            avg_train_loss = sum(train_loss_list) / world_size
            avg_train_acc = sum(train_acc_list) / world_size
            avg_val_loss = sum(val_loss_list) / world_size
            avg_val_acc = sum(val_acc_list) / world_size
            wandb.log({
                'acc': avg_train_acc,
                'loss': avg_train_loss,
                'val_acc': avg_val_acc,
                'val_loss': avg_val_loss,
                'lr': scheduler.get_last_lr()[0]
            })
        scheduler.step()
    print(f'[{datetime.now()}]#{rank}: finished training')

    if rank == 0:
        print('# final test')
        # use the underlying (non-DDP) module for the single-process test
        test_loss, test_acc, class_acc = final_test(model, test_loader,
                                                    criterion, rank)
        for key, value in class_acc.items():
            print(f'{key} : {value: .3f}')

        # save data to wandb
        if args.wandb:
            wandb.log({'test_acc': test_acc, 'test_loss': test_loss})
            wandb.finish()
        print('# all finished')
Exemplo n.º 6
0
 def end(self):
     """Terminate the active wandb run for this tracker."""
     wandb.finish()
Exemplo n.º 7
0
def main(args):
    """Train an AE/GAN-family generative model selected by ``args.model_name``.

    Builds the networks and optimizers for the chosen variant, loads (or
    computes and caches) Inception statistics for the real dataset, builds
    the optional latent "mapper" network, runs AE/mapper pretraining, then
    the main epoch loop with periodic metric/image logging to wandb and
    model checkpointing.  Stops early when the wall-clock ``time_limit`` is
    exceeded.
    """

    global inception_model_score

    # unpack CLI arguments into locals
    model_name = args.model_name
    #torch.cuda.set_device(device=args.device)
    device = args.device
    epochs = args.epochs
    batch_size = args.batch_size
    img_size = args.img_size
    save_image_interval = args.save_image_interval
    loss_calculation_interval = args.loss_calculation_interval
    latent_dim = args.latent_dim
    project_name = args.project_name
    dataset = args.dataset
    lr = args.lr
    n_iter = args.n_iter

    # fixed latent vector so sample grids are comparable across epochs
    fixed_z = make_fixed_z(model_name, latent_dim, device)

    image_shape = [3, img_size, img_size]

    # wall-clock budget for the whole run, in seconds
    time_limit_sec = timeparse(args.time_limit)

    if args.wandb:
        wandb.login()
        wandb_name = dataset + ',' + model_name + ',' + str(
            img_size) + ",convchange"
        if args.run_test: wandb_name += ', test run'
        wandb.init(project=project_name, config=args, name=wandb_name)
        config = wandb.config
    '''
    customize
    '''
    # the plain autoencoder variant never uses a latent mapper
    if model_name in ['vanilla']:
        args.mapper_inter_layer = 0

    # adversarial variants get a discriminator; pure-AE variants do not.
    # NOTE(review): any other model_name leaves encoder/decoder/discriminator
    # unbound and later code raises NameError — confirm the CLI restricts
    # model_name to these choices.
    if model_name in [
            'vanilla', 'pointMapping_but_aae', 'non-prior', 'mimic+non-prior',
            'vanilla-mimic'
    ]:
        encoder = Encoder(latent_dim, img_size).to(device)
        decoder = Decoder(latent_dim, img_size).to(device)
        discriminator = Discriminator(latent_dim).to(device)
        ae_optimizer = torch.optim.Adam(itertools.chain(
            encoder.parameters(), decoder.parameters()),
                                        lr=lr)
        d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

    elif model_name in [
            'ulearning', 'ulearning_point', 'mimic_at_last', 'mimic'
    ]:
        encoder = Encoder(latent_dim, img_size).to(device)
        decoder = Decoder(latent_dim, img_size).to(device)
        discriminator = None
        d_optimizer = None
        ae_optimizer = torch.optim.Adam(itertools.chain(
            encoder.parameters(), decoder.parameters()),
                                        lr=lr)

    ###########################################
    #####              Score              #####
    ###########################################
    # defer the expensive Inception forward passes until explicitly requested
    inception_model_score.lazy_mode(True)
    '''
    dataset 채워주세요!
    customize
    '''
    # select the training dataloader for the requested dataset
    if dataset == 'CelebA':
        train_loader = get_celebA_dataset(batch_size, img_size)
    elif dataset == 'FFHQ':
        train_loader = get_ffhq_thumbnails(batch_size, img_size)
    elif dataset == 'mnist':
        train_loader = get_mnist_dataset(batch_size, img_size)
    elif dataset == 'mnist_fashion':
        train_loader = get_mnist_fashion_dataset(batch_size, img_size)
    elif dataset == 'emnist':
        train_loader = get_emnist_dataset(batch_size, img_size)
    elif dataset == 'LSUN_dining_room':
        #wget http://dl.yf.io/lsun/scenes/dining_room_train_lmdb.zip
        #unzip dining_room_train_lmdb.zip
        #located dining_room_train_lmdb folder in dataset directory
        train_loader = get_lsun_dataset(batch_size,
                                        img_size,
                                        classes='dining_room_train')
    elif dataset == 'LSUN_classroom':
        #wget http://dl.yf.io/lsun/scenes/classroom_train_lmdb.zip
        #unzip classroom_train_lmdb.zip
        #located classroom_train_lmdb folder in dataset directory
        train_loader = get_lsun_dataset(batch_size,
                                        img_size,
                                        classes='classroom_train')
    elif dataset == 'LSUN_conference':
        #wget http://dl.yf.io/lsun/scenes/conference_room_train_lmdb.zip
        #unzip conference_room_train_lmdb.zip
        #located conference_room_train_lmdb folder in dataset directory
        train_loader = get_lsun_dataset(batch_size,
                                        img_size,
                                        classes='conference_room_train')
    elif dataset == 'LSUN_churches':
        #wget http://dl.yf.io/lsun/scenes/church_outdoor_train_lmdb.zip
        #unzip church_outdoor_train_lmdb.zip
        #located church_outdoor_train_lmdb folder in dataset directory
        train_loader = get_lsun_dataset(batch_size,
                                        img_size,
                                        classes='church_outdoor_train')
    else:
        print("dataset is forced selected to cifar10")
        train_loader = get_cifar1_dataset(batch_size, img_size)

    # cache key: md5 of the dataset repr identifies the real-image stats file
    real_images_info_file_name = hashlib.md5(
        str(train_loader.dataset).encode()).hexdigest() + '.pickle'
    if args.run_test: real_images_info_file_name += '.run_test'

    os.makedirs('../../inception_model_info', exist_ok=True)
    if os.path.exists('../../inception_model_info/' +
                      real_images_info_file_name):
        print("Using generated real image info.")
        print(train_loader.dataset)
        inception_model_score.load_real_images_info(
            '../../inception_model_info/' + real_images_info_file_name)

    else:
        inception_model_score.model_to(device)

        #put real image
        for each_batch in tqdm.tqdm(train_loader, desc='insert real dataset'):
            X_train_batch = each_batch[0]
            inception_model_score.put_real(X_train_batch)
            if args.run_test: break

        #generate real images info
        inception_model_score.lazy_forward(batch_size=256,
                                           device=device,
                                           real_forward=True)
        inception_model_score.calculate_real_image_statistics()
        #save real images info for next experiments
        inception_model_score.save_real_images_info(
            '../../inception_model_info/' + real_images_info_file_name)
        #offload inception_model
        inception_model_score.model_to('cpu')

    # build the latent mapper appropriate for the model variant
    if args.mapper_inter_layer > 0:
        if model_name in ['ulearning_point', 'mimic_at_last']:
            mapper = EachLatentMapping(
                nz=args.latent_dim,
                inter_nz=args.mapper_inter_nz,
                linear_num=args.mapper_inter_layer).to(device)
            m_optimizer = None
        elif model_name in ['pointMapping_but_aae']:
            mapper = EachLatentMapping(
                nz=args.latent_dim,
                inter_nz=args.mapper_inter_nz,
                linear_num=args.mapper_inter_layer).to(device)
            m_optimizer = torch.optim.Adam(mapper.parameters(), lr=lr)
        elif model_name in ['ulearning', 'non-prior']:
            mapper = Mapping(args.latent_dim, args.mapper_inter_nz,
                             args.mapper_inter_layer).to(device)
            m_optimizer = torch.optim.Adam(mapper.parameters(), lr=lr)
        elif model_name in ['mimic', 'vanilla-mimic']:
            mapper = Mimic(args.latent_dim, args.latent_dim,
                           args.mapper_inter_nz,
                           args.mapper_inter_layer).to(device)
            m_optimizer = torch.optim.Adam(mapper.parameters(),
                                           lr=lr,
                                           weight_decay=1e-3)
        elif model_name in [
                'mimic+non-prior',
        ]:
            mapper = MimicStack(args.latent_dim, args.latent_dim,
                                args.mapper_inter_nz,
                                args.mapper_inter_layer).to(device)
            m_optimizer = torch.optim.Adam(mapper.parameters(), lr=lr)
    else:
        # case vanilla and there is no mapper
        mapper = lambda x: x
        m_optimizer = None

    # optionally resume weights from checkpoints
    if args.load_netE != '': load_model(encoder, args.load_netE)
    if args.load_netM != '': load_model(mapper, args.load_netM)
    if args.load_netD != '': load_model(decoder, args.load_netD)

    time_start_run = time.time()

    AE_pretrain(args, train_loader, device, ae_optimizer, encoder, decoder)

    M_pretrain(args, train_loader, device, d_optimizer, m_optimizer, mapper,
               encoder, discriminator)

    # train phase
    i = 0
    loss_log = {}
    # bugfix: initialise `matric` so the post-loop `loss_log.update(matric)`
    # cannot raise NameError when the `i % save_image_interval != 0` block
    # below is skipped (e.g. the final epoch index is a multiple of the
    # interval, or the loop broke before ever generating metrics).
    matric = {}
    for i in range(1, epochs + 1):
        loss_log = train_main(args, train_loader, i, device, ae_optimizer,
                              m_optimizer, d_optimizer, encoder, decoder,
                              mapper, discriminator)
        loss_log.update({'spend time': time.time() - time_start_run})

        if check_time_over(time_start_run, time_limit_sec) == True:
            print("time limit over")
            break

        # periodic expensive metrics (FID etc.) and sample-image insertion
        if i % save_image_interval == 0:
            insert_sample_image_inception(args, i, epochs, train_loader,
                                          mapper, decoder,
                                          inception_model_score)
            matric = gen_matric(wandb, args, train_loader, encoder, mapper,
                                decoder, discriminator, inception_model_score)
            loss_log.update(matric)
        if args.wandb:
            wandb_update(wandb, i, args, train_loader, encoder, mapper,
                         decoder, device, fixed_z, loss_log)
        else:
            print(loss_log)

        # periodic checkpointing of all three networks
        if i % args.save_model_every == 0:
            now_time = str(datetime.now())
            save_model([encoder, mapper, decoder], [
                "%s[%d epoch].netE" % (now_time, i),
                "%s[%d epoch].netM" % (now_time, i),
                "%s[%d epoch].netD" % (now_time, i)
            ])

    #make last matric
    if model_name in ['mimic_at_last']:
        M_train_at_last(args, train_loader, device, d_optimizer, m_optimizer,
                        mapper, encoder, discriminator)
    if i % save_image_interval != 0:
        insert_sample_image_inception(args, i, epochs, train_loader, mapper,
                                      decoder, inception_model_score)
        matric = gen_matric(wandb, args, train_loader, encoder, mapper,
                            decoder, discriminator, inception_model_score)
    if args.wandb:
        loss_log.update(matric)
        wandb_update(wandb, i, args, train_loader, encoder, mapper, decoder,
                     device, fixed_z, loss_log)

    # final checkpoint
    now_time = str(datetime.now())
    save_model([encoder, mapper, decoder], [
        "%s[%d epoch].netE" % (now_time, i),
        "%s[%d epoch].netM" % (now_time, i),
        "%s[%d epoch].netD" % (now_time, i)
    ])

    if args.wandb: wandb.finish()
Exemplo n.º 8
0
def train_vae(model, config, train_loader, val_loader, project_name='vae'):
    """Train a VAE with an ELBO objective (reconstruction + warmed-up KL).

    Logs per-epoch train/validation statistics, prior samples and the last
    reconstructions to wandb, then saves the final model to ``SAVE_NAME``.
    """
    print(f"\nTraining will run on device: {device}")
    print(f"\nStarting training with config:")
    print(json.dumps(config, sort_keys=False, indent=4))

    # start a fresh wandb run and track the model
    wandb.init(project=project_name, config=config)
    wandb.watch(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], betas=(0.9, 0.999))

    # optional learning-rate decay schedule
    use_lr_decay = "lr_decay" in config
    if use_lr_decay:
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda_lr(**config["lr_decay"]))

    # deterministic KL warmup: ramp the KL weight from 0 to 1 over n steps
    warmup_steps = 50 if config["kl_warmup"] else 1
    warmup = DeterministicWarmup(n=warmup_steps, t_max=1)

    for epoch in range(config['epochs']):
        print(f"Epoch {epoch+1}/{config['epochs']}")

        # ---- training pass ----
        model.train()
        kl_weight = next(warmup)
        train_elbos, train_klds, train_recons = [], [], []
        for batch, _ in train_loader:
            # flatten images and move them to the training device
            flat = Variable(batch.view(batch.size(0), -1)).to(device)
            x_hat, kld = model(flat)

            recon = torch.mean(bce_loss(x_hat, flat))
            kl = torch.mean(kld)
            loss = recon + kl_weight * kl

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            train_elbos.append((-loss).mean().item())
            train_klds.append(kl.mean().item())
            train_recons.append(recon.mean().item())

        # defer the commit until the epoch's final log call
        wandb.log({
            'recon_train': torch.tensor(train_recons).mean(),
            'kl_train': torch.tensor(train_klds).mean(),
            'elbo_train': torch.tensor(train_elbos).mean()
        }, commit=False)

        if use_lr_decay:
            scheduler.step()

        # ---- validation pass ----
        model.eval()
        with torch.no_grad():
            val_elbos, val_klds, val_recons = [], [], []
            for batch, _ in val_loader:
                flat = Variable(batch.view(batch.size(0), -1)).to(device)
                x_hat, kld = model(flat)

                recon = torch.mean(bce_loss(x_hat, flat))
                kl = torch.mean(kld)
                loss = recon + kl_weight * kl

                val_elbos.append((-loss).mean().item())
                val_klds.append(kld.mean().item())
                val_recons.append(recon.mean().item())

        wandb.log({
            'recon_val': torch.tensor(val_recons).mean(),
            'kl_val': torch.tensor(val_klds).mean(),
            'elbo_val': torch.tensor(val_elbos).mean()
        }, commit=False)

        # ---- sample 16 images from the prior ----
        z_dim = config['z_dim']
        latent_size = z_dim[0] if isinstance(z_dim, list) else z_dim
        x_mu = Variable(torch.randn(16, latent_size)).to(device)
        x_sample = model.sample(x_mu)

        # log last reconstructions and fresh samples for this epoch
        log_images(x_hat, x_sample, epoch)

    # persist the trained model and upload it to wandb
    torch.save(model, SAVE_NAME)
    wandb.save(SAVE_NAME)
    wandb.finish()
Exemplo n.º 9
0
def run_policy(env, get_action, max_ep_len=None, num_episodes=100, render=True, record=False, record_project= 'benchmarking', record_name = 'trained' , data_path='', config_name='test', max_len_rb=100, benchmark=False, log_prefix=''):
    """Roll out a trained policy in ``env`` for ``num_episodes`` episodes.

    Tracks per-episode return, length and cost plus 25-episode moving
    averages.  When ``record`` is True the transitions are stored in a
    replay buffer, logged to wandb, pickled to disk and the buffer is
    returned; when ``benchmark`` is True the per-episode returns and costs
    are returned instead.
    """
    assert env is not None, \
        "Environment not found!\n\n It looks like the environment wasn't saved, " + \
        "and we can't run the agent in it. :( \n\n Check out the readthedocs " + \
        "page on Experiment Outputs for how to handle this situation."

    logger = EpochLogger()
    o, r, d, ep_ret, ep_len, n = env.reset(), 0, False, 0, 0, 0
    ep_cost = 0
    local_steps_per_epoch = int(4000 / num_procs())

    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # NOTE: despite the "_10" suffix these hold up to 25 episodes (see below)
    rew_mov_avg_10 = []
    cost_mov_avg_10 = []

    if benchmark:
        ep_costs = []
        ep_rewards = []

    if record:
        wandb.login()
        # 4 million env interactions
        wandb.init(project=record_project, name=record_name)

        rb = ReplayBuffer(size=10000,
                          env_dict={
                              "obs": {"shape": obs_dim},
                              "act": {"shape": act_dim},
                              "rew": {},
                              "next_obs": {"shape": obs_dim},
                              "done": {}})

    while n < num_episodes:
        if render:
            env.render()
            time.sleep(1e-3)

        a = get_action(o)
        next_o, r, d, info = env.step(a)

        if record:
            done_int = int(d==True)
            rb.add(obs=o, act=a, rew=r, next_obs=next_o, done=done_int)

        ep_ret += r
        ep_len += 1
        ep_cost += info['cost']

        # advance the observation before checking for episode end
        o = next_o

        if d or (ep_len == max_ep_len):
            # finish recording for this episode
            if record:
                rb.on_episode_end()

                # make directory if does not exist
                if not os.path.exists(data_path + config_name + '_episodes'):
                    os.makedirs(data_path + config_name + '_episodes')

            # keep a sliding window of the last 25 episodes
            if len(rew_mov_avg_10) >= 25:
                rew_mov_avg_10.pop(0)
                cost_mov_avg_10.pop(0)

            rew_mov_avg_10.append(ep_ret)
            cost_mov_avg_10.append(ep_cost)

            mov_avg_ret = np.mean(rew_mov_avg_10)
            mov_avg_cost = np.mean(cost_mov_avg_10)

            expert_metrics = {log_prefix + 'episode return': ep_ret,
                              log_prefix + 'episode cost': ep_cost,
                              log_prefix + '25ep mov avg return': mov_avg_ret,
                              log_prefix + '25ep mov avg cost': mov_avg_cost
                              }

            if benchmark:
                ep_rewards.append(ep_ret)
                ep_costs.append(ep_cost)

            # bugfix: only log to wandb when recording — wandb.init is only
            # called under `record`, so logging unconditionally would fail
            # (or pollute an unrelated run) when record=False
            if record:
                wandb.log(expert_metrics)
            logger.store(EpRet=ep_ret, EpLen=ep_len, EpCost=ep_cost)
            print('Episode %d \t EpRet %.3f \t EpLen %d \t EpCost %d' % (n, ep_ret, ep_len, ep_cost))
            o, r, d, ep_ret, ep_len, ep_cost = env.reset(), 0, False, 0, 0, 0
            n += 1


    logger.log_tabular('EpRet', with_min_and_max=True)
    logger.log_tabular('EpLen', average_only=True)
    logger.dump_tabular()

    if record:
        print("saving final buffer")
        bufname_pk = data_path + config_name + '_episodes/sim_data_' + str(int(num_episodes)) + '_buffer.pkl'
        # bugfix: close the file handle deterministically via a context manager
        with open(bufname_pk, 'wb') as file_pi:
            pickle.dump(rb.get_all_transitions(), file_pi)
        wandb.finish()

        return rb

    if benchmark:
        return ep_rewards, ep_costs
Exemplo n.º 10
0
def main():
    """Entry point: parse CLI flags, start a wandb run, build the CIFAR-10
    pipeline/model/optimizer/engine, and run the train/test epoch loop."""
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--engine', default='default')
    parser.add_argument('--optimizer', default='sgdm')
    parser.add_argument('--learning_rate', type=float, default=0.0001)
    parser.add_argument('--momentum', type=float, default=0.99)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--experiment_name')
    args = parser.parse_args()

    # group runs under a shared (possibly auto-generated) experiment id
    experiment_name = args.experiment_name
    if experiment_name is None:
        experiment_name = wandb.util.generate_id()
    wandb.init(project="kalman-fisher", group=experiment_name,
               config={
                   "engine": args.engine,
                   "optimizer": args.optimizer,
                   "batch_size": args.batch_size,
                   "learning_rate": args.learning_rate,
                   "momentum": args.momentum,
               })

    hparams = {
        "learning_rate": args.learning_rate,
        "momentum": args.momentum,
    }

    # normalise pixel values to [0, 1]
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                .shuffle(10000)
                .batch(args.batch_size)
                .prefetch(tf.data.AUTOTUNE))
    test_ds = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
               .batch(args.batch_size)
               .prefetch(tf.data.AUTOTUNE))

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
    hessian_fn = crossentropy_hessian_fn
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')

    model = VGGModel()
    model.build(x_train.shape)

    optimizer = make_optimizer(model, args.optimizer, hparams)
    engine = make_engine(args.engine, model, loss_fn, hessian_fn, accuracy, optimizer)

    for epoch in range(args.epochs):
        # ---- training pass ----
        engine.reset_states()
        print(f'Step: {epoch}')
        for images, labels in train_ds:
            engine.train_step(images, labels)
        print(f'Train Loss: {engine.loss_metric.result()}, '
              f'Accuracy: {accuracy.result() * 100}, ')
        wandb.log({'accuracy/train': engine.accuracy.result()}, step=epoch)
        wandb.log({'loss/train': engine.loss_metric.result()}, step=epoch)
        engine.summarize(epoch)

        # ---- evaluation pass ----
        engine.reset_states()
        for images, labels in test_ds:
            engine.test_step(images, labels)
        print(f'Test Loss: {engine.loss_metric.result()}, '
              f'Accuracy: {accuracy.result() * 100}, ')
        wandb.log({'accuracy/test': engine.accuracy.result()}, step=epoch)
        wandb.log({'loss/test': engine.loss_metric.result()}, step=epoch)

    wandb.finish()
Exemplo n.º 11
0
def main():
    """K-fold training driver for the plant time-delta regression task.

    Splits the data with GroupKFold (separately for the two camera types
    'BC' and 'LT'), then trains one PyTorch-Lightning model per fold in
    ``config.training_folds`` with wandb logging and checkpointing.
    """

    df_train = pd.read_csv('../data/train.csv')
    # tag each row by camera type inferred from its file path
    df_train['type'] = df_train['before_file_path'].apply(
        lambda x: 'BC' if 'BC' in x else 'LT')
    # df_train['before_file_path'] = df_train['before_file_path'].apply(lambda x: x.replace('.png', '_resize256.png'))
    # df_train['after_file_path'] = df_train['after_file_path'].apply(lambda x: x.replace('.png', '_resize256.png'))
    # grouping key for GroupKFold: first 5 chars after 'adjust/' (e.g. "BC_03")
    df_train['splits'] = df_train['before_file_path'].apply(
        lambda x: x.split('adjust/')[-1][:5]
    )  # + '_' + df_train['time_delta'].astype(str)

    train1 = df_train[df_train['type'] == 'BC'].reset_index(drop=True)
    train2 = df_train[df_train['type'] == 'LT'].reset_index(drop=True)

    # df_train = df_train[~df_train['splits'].isin(['BC_03', 'BC_04', 'LT_08', 'LT_05'])].reset_index(drop=True)
    print(df_train.splits.value_counts())

    # NOTE(review): train1/train2 above were built from the *unfiltered*
    # frame, and only train1/train2 feed the folds below — so this filter
    # on df_train has no effect on training. Confirm whether intentional.
    if config.type == 'BC':
        df_train = df_train[df_train['type'] == 'BC'].reset_index(drop=True)
    elif config.type == 'LT':
        df_train = df_train[df_train['type'] == 'LT'].reset_index(drop=True)

    # skf = StratifiedKFold(n_splits=config.k, random_state=config.seed, shuffle=True)
    # n_splits = list(skf.split(df_train, df_train['splits']))

    # df_train['time_delta'] = np.log(df_train['time_delta'])

    # group-aware folds, computed independently for each camera type
    gk = GroupKFold(n_splits=config.k)
    # n_splits = list(gk.split(df_train, y=df_train['time_delta'], groups=df_train['splits']))
    n_splits = list(
        gk.split(train1, y=train1['time_delta'], groups=train1['splits']))
    n_splits2 = list(
        gk.split(train2, y=train2['time_delta'], groups=train2['splits']))
    train1['n_fold'] = -1
    train2['n_fold'] = -1
    for i in range(config.k):
        train1.loc[n_splits[i][1], 'n_fold'] = i
        train2.loc[n_splits2[i][1], 'n_fold'] = i
    # df_train['n_fold'] = -1
    # for i in range(config.k):
    #     df_train.loc[n_splits[i][1], 'n_fold'] = i
    # print(df_train['n_fold'].value_counts())

    for fold in config.training_folds:
        # timestamp used in run names and checkpoint filenames
        config.start_time = time.strftime('%Y-%m-%d %H:%M',
                                          time.localtime(time.time())).replace(
                                              ' ', '_')

        logger = WandbLogger(
            name=f"{config.start_time}_{config.version}_{config.k}fold_{fold}",
            project='dacon-plant',
            config={
                key: config.__dict__[key]
                for key in config.__dict__.keys() if '__' not in key
            },
        )

        # concatenate both camera types: all other folds train, this fold validates
        tt = pd.concat([
            train1[train1['n_fold'] != fold], train2[train2['n_fold'] != fold]
        ]).reset_index(drop=True)
        vv = pd.concat([
            train1[train1['n_fold'] == fold], train2[train2['n_fold'] == fold]
        ]).reset_index(drop=True)
        # tt = df_train.loc[df_train['n_fold']!=fold].reset_index(drop=True)#.iloc[:1000]
        # vv = df_train.loc[df_train['n_fold']==fold].reset_index(drop=True)
        print(vv['splits'].value_counts())

        train_transforms = train_get_transforms()
        valid_transforms = valid_get_transforms()

        # datasets are stashed on config so the LightningModule can reach them
        config.train_dataset = PlantDataset(config,
                                            tt,
                                            mode='train',
                                            transforms=train_transforms)
        config.valid_dataset = PlantDataset(config,
                                            vv,
                                            mode='valid',
                                            transforms=valid_transforms)

        print('train_dataset input shape, label : ',
              config.train_dataset[0]['be_img'].shape,
              config.train_dataset[0]['af_img'].shape,
              config.train_dataset[0]['label'])
        print('valid_dataset input shape, label : ',
              config.valid_dataset[0]['be_img'].shape,
              config.valid_dataset[0]['af_img'].shape,
              config.valid_dataset[0]['label'])

        lr_monitor = LearningRateMonitor(
            logging_interval='epoch')  # ['epoch', 'step']
        # keep only the single best checkpoint by validation MSE
        checkpoints = ModelCheckpoint(
            'model/' + config.version,
            save_top_k=1,
            monitor='total_val_mse',
            mode='min',
            filename=f'{config.k}fold_{fold}__' +
            '{epoch}_{total_val_loss:.4f}_{total_val_mse:.4f}')

        model = plModel(config)
        trainer = pl.Trainer(
            max_epochs=config.epochs,
            gpus=1,
            log_every_n_steps=50,
            # gradient_clip_val=1000, gradient_clip_algorithm='value', # defalut : [norm, value]
            # amp_backend='native', precision=16, # amp_backend default : native
            callbacks=[checkpoints, lr_monitor],
            logger=logger)

        trainer.fit(model)
        # free GPU memory and close the wandb run before the next fold
        del model, trainer
        wandb.finish()
Exemplo n.º 12
0
 def finish(self):
     """Close the wandb run, but only when wandb logging is enabled."""
     if not USE_WB:
         return
     wandb.finish()
Exemplo n.º 13
0
# Log in to Weights & Biases and start a run with the hyper-parameters.
wb.login()

wb.init(project='bird_ID', config={'lr': 1e-3, 'bs': 32})
config = wb.config

# Drop any state left over from a previous model build.
keras.backend.clear_session()

# Small CNN over (284, 257, 1) spectrogram-sized inputs, 3 output classes.
conv_stack = [
    layers.Conv2D(filters=32, kernel_size=(4,4), strides=1, activation='relu', input_shape=(284, 257, 1)),
    layers.MaxPool2D(pool_size=(4,4)),
    layers.Conv2D(filters=64, kernel_size=(4,4), strides=1, activation='relu'),
    layers.MaxPool2D(pool_size=(4,4)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(3),
]
model = tf.keras.Sequential(conv_stack)

model.summary()
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=config.lr),
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics='accuracy',
)

AUTOTUNE = tf.data.AUTOTUNE

# Deterministically shuffle, cache, prefetch and batch both splits.
train_ds_ = train_ds.shuffle(500, seed=seed).cache().prefetch(AUTOTUNE).batch(config.bs)
val_ds_ = val_ds.shuffle(500, seed=seed).cache().prefetch(AUTOTUNE).batch(config.bs)

#LargeDataset
model.fit(train_ds_, epochs=2, validation_data=val_ds_, callbacks=[WandbCallback()])

wb.finish()

Exemplo n.º 14
0
def finish():
    """Finish the active wandb run, flushing any buffered logs."""
    wandb.finish()
Exemplo n.º 15
0
def main(cfg: DictConfig):
    """Train a detectron2 detector on the VinBigData chest-xray task, log
    metrics/artifacts to wandb, visualize predictions, and write a
    submission CSV.

    Driven by a hydra ``DictConfig``; registers detectron2 dataset catalogs,
    trains, then re-loads the final weights for inference.
    """
    cur_dir = hydra.utils.get_original_cwd()
    os.chdir(cur_dir)
    seed_everything(cfg.data.seed)

    # wandb
    wandb.init(project='VinBigData-Detection')
    wandb.config.update(dict(cfg.data))
    wandb.config.update(dict(cfg.train))
    wandb.config.update(dict(cfg.aug_kwargs_detection))
    wandb.config.update(dict(cfg.classification_kwargs))

    # omegaconf -> dict
    rep_aug_kwargs = OmegaConf.to_container(cfg.aug_kwargs_detection)

    # class id -> human-readable finding name
    class_name_dict = {
        0: 'Aortic enlargement',
        1: 'Atelectasis',
        2: 'Calcification',
        3: 'Cardiomegaly',
        4: 'Consolidation',
        5: 'ILD',
        6: 'Infiltration',
        7: 'Lung Opacity',
        8: 'Nodule/Mass',
        9: 'Other lesion',
        10: 'Pleural effusion',
        11: 'Pleural thickening',
        12: 'Pneumothorax',
        13: 'Pulmonary fibrosis',
    }

    # Setting  --------------------------------------------------
    data_dir = cfg.data.data_dir
    output_dir = cfg.data.output_dir
    img_size = cfg.data.img_size
    backbone = cfg.data.backbone
    use_class14 = cfg.data.use_class14

    # start each run from a clean output directory
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    if use_class14:
        class_name_dict.update({14: 'No finding'})

    # Register Dataset  --------------------------------------------------
    anno_df = pd.read_csv(os.path.join(data_dir, 'train_wbf_th0.7.csv'))

    # class 14 ('No finding') is dropped unless explicitly enabled
    if cfg.data.use_class14:
        pass
    else:
        anno_df = anno_df[anno_df['class_id'] != 14].reset_index(drop=True)

    # Extract rad id
    if cfg.data.rad_id != 'all':
        anno_df = anno_df[anno_df['rad_id'].isin(cfg.data.rad_id)].reset_index()

    # NOTE(review): `debug` is not defined in this function — presumably a
    # module-level flag; confirm it exists wherever this is imported.
    if debug:
        anno_df = anno_df.head(100)

    # Split train, valid data - random
    if 'valid' in cfg.data.split_method:
        # split_method like "valid_20" -> hold out 20% of image ids
        split_rate = float(cfg.data.split_method.split('_')[1]) / 100
        unique_image_ids = anno_df['image_id'].values
        unique_image_ids = np.random.RandomState(cfg.data.seed).permutation(unique_image_ids)
        train_image_ids = unique_image_ids[:int(len(unique_image_ids) * (1 - split_rate))]
        valid_image_ids = unique_image_ids[int(len(unique_image_ids) * (1 - split_rate)):]
        DatasetCatalog.register("xray_valid", lambda d='valid': get_xray_dict(anno_df, data_dir, cfg, valid_image_ids))
        MetadataCatalog.get("xray_valid").set(thing_classes=list(class_name_dict.values()))

    else:
        train_image_ids = anno_df['image_id'].values
    DatasetCatalog.register("xray_train", lambda d='train': get_xray_dict(anno_df, data_dir, cfg, train_image_ids))
    MetadataCatalog.get("xray_train").set(thing_classes=list(class_name_dict.values()))

    DatasetCatalog.register("xray_test", lambda d='test': get_test_xray_dict(data_dir))
    MetadataCatalog.get("xray_test").set(thing_classes=list(class_name_dict.values()))

    # Config  --------------------------------------------------
    detectron2_cfg = get_cfg()
    detectron2_cfg.aug_kwargs = CN(rep_aug_kwargs)
    detectron2_cfg.merge_from_file(model_zoo.get_config_file(backbone))
    detectron2_cfg.DATASETS.TRAIN = ("xray_train",)
    if 'valid' in cfg.data.split_method:
        detectron2_cfg.DATASETS.TEST = ("xray_valid",)
        # evaluate 10 times over the course of training
        detectron2_cfg.TEST.EVAL_PERIOD = cfg.train.max_iter // 10
    else:
        detectron2_cfg.DATASETS.TEST = ()
    detectron2_cfg.INPUT.MIN_SIZE_TRAIN = (img_size,)
    detectron2_cfg.INPUT.MAX_SIZE_TRAIN = img_size
    detectron2_cfg.DATALOADER.NUM_WORKERS = cfg.train.num_workers
    detectron2_cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(backbone)
    detectron2_cfg.SOLVER.IMS_PER_BATCH = cfg.train.ims_per_batch
    detectron2_cfg.SOLVER.BASE_LR = cfg.train.lr
    detectron2_cfg.SOLVER.MAX_ITER = cfg.train.max_iter
    detectron2_cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
    detectron2_cfg.SOLVER.WARMUP_ITERS = 2000
    detectron2_cfg.SOLVER.CHECKPOINT_PERIOD = 200000
    detectron2_cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = cfg.train.batch_size_per_image
    detectron2_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 15 if use_class14 else 14
    detectron2_cfg.OUTPUT_DIR = output_dir
    detectron2_cfg.SEED = cfg.data.seed
    detectron2_cfg.PIXEL_MEAN = [103.530, 116.280, 123.675]
    detectron2_cfg.PIXEL_STD = [1.0, 1.0, 1.0]

    # Train  --------------------------------------------------
    os.makedirs(detectron2_cfg.OUTPUT_DIR, exist_ok=True)
    # trainer = DefaultTrainer(detectron2_cfg)
    trainer = MyTrainer(detectron2_cfg)
    trainer.resume_or_load(resume=True)
    trainer.train()

    # Rename Last Weight
    renamed_model = f"{backbone.split('.')[0].replace('/', '-')}.pth"
    os.rename(os.path.join(cfg.data.output_dir, 'model_final.pth'),
              os.path.join(cfg.data.output_dir, renamed_model))

    # Logging
    for model_path in glob.glob(os.path.join(cfg.data.output_dir, '*.pth')):
        wandb.save(model_path)

    # Inference Setting  ------------------------------------------------------
    detectron2_cfg = get_cfg()
    detectron2_cfg.merge_from_file(model_zoo.get_config_file(backbone))
    detectron2_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 15 if use_class14 else 14
    detectron2_cfg.MODEL.WEIGHTS = os.path.join(output_dir, renamed_model)  # path to the model we just trained
    detectron2_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = cfg.data.score_th   # set a custom testing threshold

    predictor = DefaultPredictor(detectron2_cfg)
    dataset_dicts = get_test_xray_dict(data_dir)

    # Visualize  ------------------------------------------------------
    target_image_ids = ['9a5094b2563a1ef3ff50dc5c7ff71345',
                        '22b8e616a61bbc4caaed0cf23b7159df',
                        '001d127bad87592efe45a5c7678f8b8d',
                        '008b3176a7248a0a189b5731ac8d2e95']

    # render sample predictions at several score thresholds
    for th in [0, 0.2, 0.5, 0.7]:
        visualize(target_image_ids, data_dir, output_dir, predictor, score_th=th)

    # Metrics
    if os.path.exists(os.path.join(output_dir, 'metrics.json')):
        metrics_df = pd.read_json(os.path.join(output_dir, 'metrics.json'), orient="records", lines=True)
        mdf = metrics_df.sort_values("iteration")

        mdf3 = mdf[~mdf["bbox/AP75"].isna()].reset_index(drop=True)
        for i in range(len(mdf3)):
            row = mdf3.iloc[i]
            # NOTE(review): logs the bbox/AP75 metric under the key 'AP40' —
            # confirm which IoU threshold is actually intended.
            wandb.log({'AP40': row["bbox/AP75"] / 100.})

        best_score = mdf3["bbox/AP75"].max() / 100.
        wandb.log({'Best-AP40-Score': best_score})

    # Inference  ------------------------------------------------------
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    sub = get_submission(dataset_dicts, cfg, predictor, device)

    # timestamp in JST (UTC+9) for the submission filename
    now = datetime.datetime.now() + datetime.timedelta(hours=9)
    now = now.strftime("%Y%m%d-%H%M%S")

    filename = f'submission_{now}.csv'
    sub.to_csv(os.path.join('./submission', filename), index=False)
    wandb.save(os.path.join('./submission', filename))
    # presumably gives the async wandb.save upload time to complete — verify
    time.sleep(30)

    wandb.finish()
    DatasetCatalog.clear()
Exemplo n.º 16
0
def test_molecule_from_invalid_smiles(mocked_run):
    """wandb.Molecule.from_smiles should reject a string that is not valid SMILES."""
    invalid_smiles = "TEST"
    with pytest.raises(ValueError):
        wandb.Molecule.from_smiles(invalid_smiles)
    wandb.finish()
Exemplo n.º 17
0
def main(gpu, params):
    """Load the dataset and train the translation model on one process.

    Supports bilingual (exactly two languages) and multilingual (more than
    two) training. In distributed mode this function runs once per GPU and
    joins the NCCL process group before any data loading.

    Args:
        gpu: local GPU index for this process.
        params: experiment parameter namespace (langs, batch_size, vocab_size,
            dataset, checkpoint, tokenizer, wandb flags, distributed config, ...).

    Raises:
        NotImplementedError: if fewer than two languages are configured.
    """
    # global rank of this process across all nodes
    rank = params.nr * params.gpus + gpu
    if params.distributed:
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=params.world_size,
                                rank=rank)
    seed_all(SEED)

    # get gpu device
    if params.device == 'gpu':
        device = torch.device(gpu)
    else:
        device = 'cpu'

    # only wandb on main process
    # (fix: the previous `config = wandb.config` local was never read and was
    # undefined on non-zero ranks; removed)
    if rank == 0 and params.wandb:
        wandb.init(project='mnmt',
                   entity='nlp-mnmt-project',
                   config={
                       k: v
                       for k, v in params.__dict__.items()
                       if isinstance(v, (float, int, str))
                   })
    logger, params = setup(params)

    # load data and train for required experiment
    if len(params.langs) == 2:
        # bilingual translation

        # load tokenizers if continuing
        if params.checkpoint:
            tokenizers = []
            for lang in params.langs:
                tokenizers.append(
                    Tokenizer.from_file(logger.root_path + '/' + lang +
                                        '_tokenizer.json'))
        else:
            if params.tokenizer is not None:
                if len(params.tokenizer) == 2:
                    tokenizers = [
                        Tokenizer.from_file('pretrained/' + tok + '.json')
                        for tok in params.tokenizer
                    ]
                else:
                    print('Wrong number of tokenizers passed. Retraining.')
                    tokenizers = None
            else:
                tokenizers = None

        train_dataloader, val_dataloader, test_dataloader, _ = preprocess.load_and_preprocess(
            params.langs,
            params.batch_size,
            params.vocab_size,
            params.dataset,
            multi=False,
            path=logger.root_path,
            tokenizer=tokenizers,
            distributed=params.distributed,
            world_size=params.world_size,
            rank=rank)

        train(rank,
              device,
              logger,
              params,
              train_dataloader,
              val_dataloader=val_dataloader,
              verbose=params.verbose)

    elif len(params.langs) > 2:
        # multilingual translation: a single shared tokenizer across languages

        # load tokenizers if continuing
        if params.checkpoint:
            tokenizer = Tokenizer.from_file(logger.root_path +
                                            '/multi_tokenizer.json')
        else:
            if params.tokenizer is not None:
                tokenizer = Tokenizer.from_file('pretrained/' +
                                                params.tokenizer + '.json')
            else:
                tokenizer = None

        train_dataloader, val_dataloader, test_dataloader, tokenizer = preprocess.load_and_preprocess(
            params.langs,
            params.batch_size,
            params.vocab_size,
            params.dataset,
            multi=True,
            path=logger.root_path,
            tokenizer=tokenizer,
            distributed=params.distributed,
            world_size=params.world_size,
            rank=rank)

        train(rank,
              device,
              logger,
              params,
              train_dataloader,
              val_dataloader=val_dataloader,
              tokenizer=tokenizer,
              verbose=params.verbose)

    else:
        raise NotImplementedError

    # end wandb process to avoid hanging
    if rank == 0 and params.wandb:
        wandb.finish()
Exemplo n.º 18
0
def test_molecule_from_rdkit_invalid_input(mocked_run):
    """wandb.Molecule.from_rdkit should raise ValueError for a bogus file name."""
    with pytest.raises(ValueError):
        wandb.Molecule.from_rdkit("test")
    wandb.finish()
Exemplo n.º 19
0
def main(args):
    """Run LaMAML with a replay buffer on SplitCIFAR100 (20 experiences).

    Trains on each experience in sequence, evaluating on the full test stream
    after every experience. Logs to Weights & Biases when a project name is
    given via --wandb_project.
    """
    # --- CONFIG
    device = torch.device(
        f"cuda:{args.cuda}"
        if torch.cuda.is_available() and args.cuda >= 0
        else "cpu"
    )

    # --- SCENARIO CREATION
    # task-incremental split: 20 experiences, each carrying its own task id
    scenario = SplitCIFAR100(n_experiences=20, return_task_id=True)
    config = {"scenario": "SplitCIFAR100"}

    # MODEL CREATION
    model = MTSimpleCNN()

    # choose some metrics and evaluation method
    loggers = [InteractiveLogger()]
    if args.wandb_project != "":
        wandb_logger = WandBLogger(
            project_name=args.wandb_project,
            run_name="LaMAML_" + config["scenario"],
            config=config,
        )
        loggers.append(wandb_logger)

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True
        ),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=loggers,
    )

    # LAMAML STRATEGY
    # reservoir-sampled replay memory, shared with the replay plugin below
    rs_buffer = ReservoirSamplingBuffer(max_size=200)
    replay_plugin = ReplayPlugin(
        mem_size=200,
        batch_size=10,
        batch_size_mem=10,
        task_balanced_dataloader=False,
        storage_policy=rs_buffer,
    )

    cl_strategy = LaMAML(
        model,
        torch.optim.SGD(model.parameters(), lr=0.1),
        CrossEntropyLoss(),
        n_inner_updates=5,
        second_order=True,
        grad_clip_norm=1.0,
        learn_lr=True,
        lr_alpha=0.25,
        sync_update=False,
        train_mb_size=10,
        train_epochs=10,
        eval_mb_size=100,
        device=device,
        plugins=[replay_plugin],
        evaluator=eval_plugin,
    )

    # TRAINING LOOP
    print("Starting experiment...")
    # NOTE(review): results is accumulated but never returned or logged here
    results = []
    for experience in scenario.train_stream:
        print("Start of experience ", experience.current_experience)
        cl_strategy.train(experience)
        print("Training completed")

        print("Computing accuracy on the whole test set")
        results.append(cl_strategy.eval(scenario.test_stream))

    if args.wandb_project != "":
        wandb.finish()
Exemplo n.º 20
0
def test_html_str(mocked_run):
    """Binding an Html object to a run should materialise its backing file."""
    markup = "<html><body><h1>Hello</h1></body></html>"
    doc = wandb.Html(markup)
    doc.bind_to_run(mocked_run, "rad", "summary")
    wandb.Html.seq_to_json([doc], mocked_run, "rad", "summary")
    assert os.path.exists(doc._path)
    wandb.finish()
Exemplo n.º 21
0
def test_live_policy_policy(mocked_live_policy):
    """The live-policy fixture should report the "live" policy."""
    expected_policy = "live"
    assert mocked_live_policy.policy == expected_policy
    wandb.finish()
Exemplo n.º 22
0
def test_object3d_dict(mocked_run):
    """An Object3D built from a lidar dict should serialise as an object3D file."""
    point_cloud = wandb.Object3D({"type": "lidar/beta"})
    point_cloud.bind_to_run(mocked_run, "object3D", 0)
    serialised = point_cloud.to_json(mocked_run)
    assert serialised["_type"] == "object3D-file"
    wandb.finish()
Exemplo n.º 23
0
def main():
    ''' Entry point for WILDS training/evaluation with optional active learning.

    Parses the command line, builds the WILDS dataset, split loaders and the
    chosen algorithm, then either trains (optionally resuming from the most
    recent checkpoint) or evaluates a saved model.
    Set default hyperparams in default_hyperparams.py.
    '''
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=wilds.supported_datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )
    parser.add_argument('--pretrained_model_path',
                        default=None,
                        type=str,
                        help="Specify a path to a pretrained model's weights")

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to download the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.'
    )
    parser.add_argument('--version', default=None, type=str)

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--unlabeled_n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--unlabeled_batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        'Number of batches to process before stepping optimizer and/or schedulers. If > 1, we simulate having a larger effective batch size (though batchnorm behaves differently).'
    )

    # Active Learning
    parser.add_argument('--active_learning',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument(
        '--target_split',
        default="test",
        type=str,
        help=
        'Split from which to sample labeled examples and use as unlabeled data for self-training.'
    )
    parser.add_argument(
        '--use_target_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=True,
        help=
        "If false, we sample target labels and remove them from the eval set, but don't actually train on them."
    )
    parser.add_argument(
        '--use_source_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        "Train on labeled source examples (perhaps in addition to labeled target examples.)"
    )
    parser.add_argument(
        '--upsample_target_labeled',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        "If concatenating source labels, upsample target labels s.t. our labeled batches are 1/2 src and 1/2 tgt."
    )
    parser.add_argument('--selection_function',
                        choices=supported.selection_functions)
    parser.add_argument(
        '--selection_function_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        "keyword arguments for selection fn passed as key1=value1 key2=value2")
    parser.add_argument(
        '--selectby_fields',
        nargs='+',
        help=
        "If set, acts like a grouper and n_shots are acquired per selection group (e.g. y x hospital selects K examples per y x hospital)."
    )
    parser.add_argument('--n_shots',
                        type=int,
                        help="number of shots (labels) to actively acquire")

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )
    parser.add_argument('--freeze_featurizer',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        help="Only train classifier weights")
    parser.add_argument(
        '--teacher_model_path',
        type=str,
        help=
        'Path to teacher model weights. If this is defined, pseudolabels will first be computed for unlabeled data before anything else runs.'
    )
    parser.add_argument('--dropout_rate', type=float)

    # Transforms
    parser.add_argument('--transform', choices=supported.transforms)
    parser.add_argument('--additional_labeled_transform',
                        type=parse_none,
                        choices=supported.additional_transforms)
    parser.add_argument('--additional_unlabeled_transform',
                        type=parse_none,
                        nargs='+',
                        choices=supported.additional_transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)
    parser.add_argument(
        '--randaugment_n',
        type=int,
        help=
        'N parameter of RandAugment - the number of transformations to apply.')

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--maml_first_order',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--metalearning_k', type=int)
    parser.add_argument('--metalearning_adapt_lr', type=float)
    parser.add_argument('--metalearning_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--self_training_labeled_weight',
                        type=float,
                        help='Weight of labeled loss')
    parser.add_argument('--self_training_unlabeled_weight',
                        type=float,
                        help='Weight of unlabeled loss')
    parser.add_argument('--self_training_threshold', type=float)
    parser.add_argument(
        '--pseudolabel_T2',
        type=float,
        help=
        'Percentage of total iterations at which to end linear scheduling and hold unlabeled weight at the max value'
    )
    parser.add_argument('--soft_pseudolabels',
                        default=False,
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--algo_log_metric')

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--process_outputs_function',
                        choices=supported.process_outputs_functions)
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--eval_splits', nargs='+', default=['val', 'test'])
    parser.add_argument(
        '--save_splits',
        nargs='+',
        default=['test'],
        help=
        'If save_pred_step or save_pseudo_step are set, then this sets which splits to save pred / pseudos for. Must be a subset of eval_splits.'
    )
    parser.add_argument('--eval_additional_every',
                        default=1,
                        type=int,
                        help='Eval additional splits every _ training epochs.')
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--eval_epoch',
        default=None,
        type=int,
        help=
        'If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.'
    )

    # Misc
    parser.add_argument('--device', type=int, nargs='+', default=[0])
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_model_step', type=int)
    parser.add_argument('--save_pred_step', type=int)
    parser.add_argument('--save_pseudo_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--resume',
        type=parse_bool,
        const=True,
        nargs='?',
        default=False,
        help=
        'Whether to resume from the most recent saved model in the current log_dir.'
    )

    # Weights & Biases
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--wandb_api_key_path',
        type=str,
        help=
        "Path to Weights & Biases API Key. If use_wandb is set to True and this argument is not specified, user will be prompted to authenticate."
    )
    parser.add_argument('--wandb_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={},
                        help="Will be passed directly into wandb.init().")

    config = parser.parse_args()
    # fill in dataset/algorithm-specific defaults for anything left unset
    config = populate_defaults(config)

    # Set device
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        if len(config.device) > device_count:
            raise ValueError(
                f"Specified {len(config.device)} devices, but only {device_count} devices found."
            )
        # more than one device id => wrap the model in DataParallel later
        config.use_data_parallel = len(config.device) > 1
        try:
            device_str = ",".join(map(str, config.device))
            config.device = torch.device(f"cuda:{device_str}")
        except RuntimeError as e:
            print(
                f"Failed to initialize CUDA. Using torch.device('cuda') instead. Error: {str(e)}"
            )
            config.device = torch.device("cuda")
    else:
        config.use_data_parallel = False
        config.device = torch.device("cpu")

    ## Initialize logs
    # config.mode is the file mode for the text logger: append when resuming
    # or evaluating an existing run, write (truncate) for a fresh run.
    # NOTE(review): the local `resume` set here is never read afterwards
    # (later code checks config.resume instead) — confirm it can be dropped.
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        config.mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        config.mode = 'a'
    else:
        resume = False
        config.mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), config.mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Algorithms that use unlabeled data must be run in active learning mode,
    # because otherwise we have no unlabeled data.
    if config.algorithm in ["PseudoLabel", "FixMatch", "NoisyStudent"]:
        assert config.active_learning

    # Data
    full_dataset = wilds.get_dataset(dataset=config.dataset,
                                     version=config.version,
                                     root_dir=config.root_dir,
                                     download=config.download,
                                     split_scheme=config.split_scheme,
                                     **config.dataset_kwargs)

    # In this project, we sometimes train on batches of mixed splits, e.g. some train labeled examples and test labeled examples
    # Within each batch, we may want to sample uniformly across split, or log the train v. test label balance
    # To facilitate this, we'll hack the WILDS dataset to include each point's split in the metadata array
    add_split_to_wilds_dataset_metadata_array(full_dataset)

    # To modify data augmentation, modify the following code block.
    # If you want to use transforms that modify both `x` and `y`,
    # set `do_transform_y` to True when initializing the `WILDSSubset` below.
    train_transform = initialize_transform(
        transform_name=config.transform,
        config=config,
        dataset=full_dataset,
        additional_transform=config.additional_labeled_transform,
        is_training=True)
    eval_transform = initialize_transform(transform_name=config.transform,
                                          config=config,
                                          dataset=full_dataset,
                                          is_training=False)

    # Define any special transforms for the algorithms that use unlabeled data
    # if config.algorithm == "FixMatch":
    #     # For FixMatch, we need our loader to return batches in the form ((x_weak, x_strong), m)
    #     # We do this by initializing a special transform function
    #     unlabeled_train_transform = initialize_transform(
    #         config.transform, config, full_dataset, is_training=True, additional_transform="fixmatch"
    #     )
    # else:
    unlabeled_train_transform = initialize_transform(
        config.transform,
        config,
        full_dataset,
        is_training=True,
        additional_transform=config.additional_unlabeled_transform)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    # Build one dict of loaders/loggers per split; train/val are verbose.
    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False

        data = full_dataset.get_subset(split,
                                       frac=config.frac,
                                       transform=transform)

        datasets[split] = configure_split_dict(
            data=data,
            split=split,
            split_name=full_dataset.split_names[split],
            get_train=(split == 'train'),
            get_eval=(split != 'train'),
            verbose=verbose,
            grouper=train_grouper,
            batch_size=config.batch_size,
            config=config)

        pseudolabels = None
        if config.algorithm == "NoisyStudent" and config.target_split == split:
            # Infer teacher outputs on unlabeled examples in sequential order
            # During forward pass, ensure we are not shuffling and not applying strong augs
            print(
                f"Inferring teacher pseudolabels on {config.target_split} for Noisy Student"
            )
            assert config.teacher_model_path is not None
            if not config.teacher_model_path.endswith(".pth"):
                # Use the best model
                config.teacher_model_path = os.path.join(
                    config.teacher_model_path,
                    f"{config.dataset}_seed:{config.seed}_epoch:best_model.pth"
                )
            teacher_model = initialize_model(
                config, infer_d_out(full_dataset)).to(config.device)
            load(teacher_model,
                 config.teacher_model_path,
                 device=config.device)
            # Infer teacher outputs on weakly augmented unlabeled examples in sequential order
            weak_transform = initialize_transform(
                transform_name=config.transform,
                config=config,
                dataset=full_dataset,
                is_training=True,
                additional_transform="weak")
            unlabeled_split_dataset = full_dataset.get_subset(
                split, transform=weak_transform, frac=config.frac)
            sequential_loader = get_eval_loader(
                loader=config.eval_loader,
                dataset=unlabeled_split_dataset,
                grouper=train_grouper,
                batch_size=config.unlabeled_batch_size,
                **config.loader_kwargs)
            pseudolabels = infer_predictions(teacher_model, sequential_loader,
                                             config)
            # free teacher weights before training starts
            del teacher_model

        if config.active_learning and config.target_split == split:
            datasets[split]['label_manager'] = LabelManager(
                subset=data,
                train_transform=train_transform,
                eval_transform=eval_transform,
                unlabeled_train_transform=unlabeled_train_transform,
                pseudolabels=pseudolabels)

    if config.use_wandb:
        initialize_wandb(config)

    # Logging dataset info
    # Show class breakdown if feasible
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1 and full_dataset.n_classes <= 10:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    ## Schedulers are initialized as if we will iterate over "train" split batches.
    ## If we train on another split (e.g. labeled test), we'll re-initialize schedulers later using algorithm.change_n_train_steps()
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)
    if config.freeze_featurizer: freeze_features(algorithm)

    if config.active_learning:
        select_grouper = CombinatorialGrouper(
            dataset=full_dataset, groupby_fields=config.selectby_fields)
        selection_fn = initialize_selection_function(
            config, algorithm, select_grouper, algo_grouper=train_grouper)

    # Resume from most recent model in log_dir
    model_prefix = get_model_prefix(datasets['train'], config)
    if not config.eval_only:
        ## If doing active learning, expects to load a model trained on source
        resume_success = False
        if config.resume:
            save_path = model_prefix + 'epoch:last_model.pth'
            if not os.path.exists(save_path):
                # fall back to the highest-numbered epoch checkpoint on disk
                epochs = [
                    int(file.split('epoch:')[1].split('_')[0])
                    for file in os.listdir(config.log_dir)
                    if file.endswith('.pth')
                ]
                if len(epochs) > 0:
                    latest_epoch = max(epochs)
                    save_path = model_prefix + f'epoch:{latest_epoch}_model.pth'
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path,
                                                   config.device)
                # also load previous selections

                epoch_offset = prev_epoch + 1
                config.selection_function_kwargs[
                    'load_selection_path'] = config.log_dir
                logger.write(
                    f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}\n'
                )
                resume_success = True
            except FileNotFoundError:
                # no usable checkpoint: start from scratch below
                pass

        if resume_success == False:
            epoch_offset = 0
            best_val_metric = None

        # Log effective batch size
        logger.write((
            f'\nUsing gradient_accumulation_steps {config.gradient_accumulation_steps} means that'
        ) + (
            f' the effective labeled batch size is {config.batch_size * config.gradient_accumulation_steps}'
        ) + (
            f' and the effective unlabeled batch size is {config.unlabeled_batch_size * config.gradient_accumulation_steps}'
            if config.unlabeled_batch_size else '') + (
                '. Updates behave as if torch loaders have drop_last=False\n'))

        if config.active_learning:
            # create new labeled/unlabeled test splits
            train_split, unlabeled_split = run_active_learning(
                selection_fn=selection_fn,
                datasets=datasets,
                grouper=train_grouper,
                config=config,
                general_logger=logger,
                full_dataset=full_dataset)
            # reset schedulers, which were originally initialized to schedule based on the 'train' split
            # one epoch = one pass over labeled data
            algorithm.change_n_train_steps(
                new_n_train_steps=infer_n_train_steps(
                    datasets[train_split]['train_loader'], config),
                config=config)
        else:
            train_split = "train"
            unlabeled_split = None

        train(algorithm=algorithm,
              datasets=datasets,
              train_split=train_split,
              val_split="val",
              unlabeled_split=unlabeled_split,
              general_logger=logger,
              config=config,
              epoch_offset=epoch_offset,
              best_val_metric=best_val_metric)

    else:
        # eval-only path: load the requested (or best) checkpoint and evaluate
        if config.eval_epoch is None:
            eval_model_path = model_prefix + 'epoch:best_model.pth'
        else:
            eval_model_path = model_prefix + f'epoch:{config.eval_epoch}_model.pth'
        best_epoch, best_val_metric = load(algorithm, eval_model_path,
                                           config.device)
        if config.eval_epoch is None:
            epoch = best_epoch
        else:
            epoch = config.eval_epoch

        if config.active_learning:
            # create new labeled/unlabeled test splits
            config.selection_function_kwargs[
                'load_selection_path'] = config.log_dir
            run_active_learning(selection_fn=selection_fn,
                                datasets=datasets,
                                grouper=train_grouper,
                                config=config,
                                general_logger=logger,
                                full_dataset=full_dataset)

        evaluate(algorithm=algorithm,
                 datasets=datasets,
                 epoch=epoch,
                 general_logger=logger,
                 config=config)

    if config.use_wandb:
        wandb.finish()
    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()
Exemplo n.º 24
0
def test_object3d_dict_invalid(mocked_run):
    """Object3D must reject a dict whose "type" key is not a known type.

    Fix: the original bound the (never-created) result to an unused local
    ``obj`` inside ``pytest.raises`` (flake8 F841); the binding is dropped.
    """
    with pytest.raises(ValueError):
        wandb.Object3D({"type": "INVALID"})
    wandb.finish()
Exemplo n.º 25
0
def teardown():
    """Finish any active wandb run and remove local run artifacts.

    Deletes the ``wandb`` and ``artifacts`` directories if they exist so
    each test run starts from a clean working tree.
    """
    wandb.finish()
    for leftover_dir in ("wandb", "artifacts"):
        if os.path.isdir(leftover_dir):
            shutil.rmtree(leftover_dir)
Exemplo n.º 26
0
def test_object3d_dict_invalid_string(mocked_run):
    """Object3D must reject a bare string that is not a supported file/path.

    Fix: the original bound the (never-created) result to an unused local
    ``obj`` inside ``pytest.raises`` (flake8 F841); the binding is dropped.
    """
    with pytest.raises(ValueError):
        wandb.Object3D("INVALID")
    wandb.finish()
Exemplo n.º 27
0
def train_model(mods, use_ae, h, w, use_coronal, use_sagital, use_controls,
                latent_dim, batch_size, lr, weight_decay, weight_of_class,
                n_epochs, n_epochs_ae, p, save_masks, parallel,
                experiment_name, temporal_division, seed):
    """Train one patch classifier per FCD subject in a leave-one-out loop.

    For every subject index this (optionally) pretrains an autoencoder via
    ``train_ae``, builds train/val patch dataloaders that hold the subject
    out, trains a ``PatchModel`` with a weighted binary cross-entropy loss,
    logs losses to wandb, saves a checkpoint, and finally generates
    probability masks for the held-out subject with ``FCDMaskGenerator``.

    Returns:
        np.ndarray: per-subject top-k scores.

    NOTE(review): ``use_controls`` and ``temporal_division`` are accepted
    but never used here — confirm whether they were meant to influence
    dataset construction.
    """

    # Seed every RNG in play so runs are reproducible; cudnn is forced to
    # deterministic kernels at the cost of some speed.
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    def launch():
        """Train/evaluate the classifier for the current ``idx`` (closure over
        ``model``, ``criterion``, ``optimizer`` and the dataloaders)."""
        train_loss_history = []
        val_loss_range = [0]
        val_loss_history = []
        model.eval()

        if use_ae:
            # Load the pretrained encoder weights and freeze the encoder for
            # one warm-up epoch (unfrozen again at epoch 1 below).
            encoder = model.module.encoder if parallel else model.encoder
            encoder.load_state_dict(
                torch.load(
                    f'./checkpoints/{experiment_name}/encoder_{str(idx).zfill(2)}.pth'
                ))
            for param in encoder.parameters():
                param.requires_grad = False

        # Log pre-training losses so the wandb curves start at "epoch 0".
        # no_grad avoids building autograd graphs during pure evaluation;
        # logged values are unchanged.
        with torch.no_grad():
            overall_val_loss = 0
            for i, batch in enumerate(val_dataloader):
                X_batch, y_batch = batch[0].cuda(), batch[1].cuda()
                logits = model(X_batch)
                loss = criterion(logits[:, 0], y_batch)
                overall_val_loss += loss.item()
            overall_val_loss = overall_val_loss / len(val_dataloader)

            overall_train_loss = 0
            for i, batch in enumerate(train_dataloader):
                X_batch, y_batch = batch[0].cuda(), batch[1].cuda()
                logits = model(X_batch)
                loss = criterion(logits[:, 0], y_batch)
                overall_train_loss += loss.item()
            overall_train_loss = overall_train_loss / len(train_dataloader)

        wandb.log({
            f'val-classification-{idx}': overall_val_loss,
            f'train-classification-{idx}': overall_train_loss
        })

        for epoch in range(n_epochs):
            if use_ae and epoch == 1:
                # Unfreeze the encoder after the warm-up epoch. BUGFIX: the
                # original accessed ``model.encoder`` here, which raises
                # AttributeError when the model is wrapped in nn.DataParallel
                # (the freeze path above already used ``model.module``).
                encoder = model.module.encoder if parallel else model.encoder
                for param in encoder.parameters():
                    param.requires_grad = True

            model.train()
            overall_train_loss = 0
            for i, batch in enumerate(train_dataloader):
                X_batch, y_batch = batch[0].cuda(), batch[1].cuda()
                y_predicted = model(X_batch)
                loss = criterion(y_predicted[:, 0], y_batch)
                loss.backward()
                overall_train_loss += loss.item()
                optimizer.step()
                optimizer.zero_grad()
                train_loss_history.append(loss.item())

            model.eval()
            overall_loss = 0
            y_pred = []
            y_val = []
            with torch.no_grad():
                for i, batch in enumerate(val_dataloader):
                    X_batch, y_batch = batch[0].cuda(), batch[1].cuda()
                    logits = model(X_batch)
                    predicted_labels = torch.sigmoid(logits[:, 0])

                    loss = criterion(logits[:, 0], y_batch)
                    overall_loss += loss.item()
                    y_pred += list(predicted_labels.detach().cpu().numpy())
                    y_val += list(y_batch.detach().cpu().numpy())
            wandb.log({
                f'val-classification-{idx}':
                overall_loss / len(val_dataloader),
                f'train-classification-{idx}':
                overall_train_loss / len(train_dataloader)
            })

            y_val = np.array(y_val)
            y_pred = np.array(y_pred)

            val_loss_history.append(overall_loss / len(val_dataloader))
            val_loss_range.append(val_loss_range[-1] + len(train_dataloader))

        torch.save(
            model.state_dict(),
            f'./checkpoints/{experiment_name}/model_{str(idx).zfill(2)}.pth')

        # Top-k metric is computed on the LAST epoch's validation predictions.
        top_k_score = calculate_top_k_metric(y_val, y_pred)
        top_k_scores.append(top_k_score)

        # f-prefix removed from the key: it had no placeholders (flake8 F541).
        wandb.log({'top_k_scores': top_k_score})

    # Axial plane is always used; coronal/sagital each add one input dim.
    nb_of_dims = 1 + 1 * int(use_coronal) + 1 * int(use_sagital)
    top_k_scores = []

    for idx in np.arange(NB_OF_FCD_SUBJECTS):
        # Typo fixed ("doint" -> "doing"); no f-string needed here.
        print('Model training, doing subject: ', idx)
        if use_ae:
            train_ae(mods=mods,
                     h=h,
                     w=w,
                     use_coronal=use_coronal,
                     use_sagital=use_sagital,
                     latent_dim=latent_dim,
                     batch_size=batch_size,
                     lr=lr,
                     n_epochs=n_epochs_ae,
                     p=p,
                     loo_idx=idx,
                     parallel=parallel,
                     experiment_name=experiment_name)

        # NOTE(review): ``deleted_idxs`` is computed but never used — confirm
        # whether it was meant to be passed to the datasets below.
        deleted_idxs = [idx]
        if use_ae:
            deleted_idxs += [
                i for i in range(NB_OF_FCD_SUBJECTS, NB_OF_FCD_SUBJECTS +
                                 NB_OF_NOFCD_SUBJECTS)
            ]

        train_dataset = PatchTrainDataset('./data/saved_patches/', True,
                                          2 * mods, h, w, batch_size, idx)
        val_dataset = PatchValDataset('./data/saved_patches/', True, 2 * mods,
                                      h, w, idx, DEFAULT_NB_OF_PATCHES,
                                      batch_size)

        # batch_size=1 because the datasets yield pre-batched patch tensors
        # and dummy_collate unwraps them.
        train_dataloader = data.DataLoader(train_dataset,
                                           batch_size=1,
                                           shuffle=True,
                                           num_workers=NUM_WORKERS,
                                           pin_memory=True,
                                           drop_last=False,
                                           collate_fn=dummy_collate)
        val_dataloader = data.DataLoader(val_dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=NUM_WORKERS,
                                         pin_memory=True,
                                         drop_last=False,
                                         collate_fn=dummy_collate)

        model = PatchModel(h, w, mods, nb_of_dims, latent_dim, p).cuda()

        if parallel:
            model = nn.DataParallel(model)

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
        # Positive class is up-weighted by ``weight_of_class``.
        weights = torch.FloatTensor([1., weight_of_class]).cuda()
        criterion = lambda output, target: weighted_binary_cross_entropy(
            output, target, weights=weights)

        launch()

        print('Top-k score: ', top_k_scores[-1])

        mask_generator = FCDMaskGenerator(
            h=h,
            w=w,
            mods=mods,
            nb_of_dims=nb_of_dims,
            latent_dim=latent_dim,
            use_coronal=use_coronal,
            use_sagital=use_sagital,
            p=p,
            experiment_name=experiment_name,
            parallel=parallel,
            model_weights=
            f'./checkpoints/{experiment_name}/model_{str(idx).zfill(2)}.pth')

        mask_generator.get_probability_masks(idx, save_masks=save_masks)
    top_k_scores = np.array(top_k_scores)
    wandb.finish()
    return top_k_scores
Exemplo n.º 28
0
def test_object3d_gltf(mocked_run):
    """A .gltf fixture should serialize as the 'object3D-file' media type."""
    box = wandb.Object3D(utils.fixture_open("Box.gltf"))
    box.bind_to_run(mocked_run, "object3D", 0)
    serialized = box.to_json(mocked_run)
    assert serialized["_type"] == "object3D-file"
    wandb.finish()
Exemplo n.º 29
0
                }
                if not avoid_model_calls:
                    log['image'] = wandb.Image(image, caption=decoded_text)

            wandb.log(log)

    if LR_DECAY and not using_deepspeed:
        # Scheduler is automatically progressed after the step when
        # using DeepSpeed.
        distr_scheduler.step(loss)

    if distr_backend.is_root_worker():
        # save trained model to wandb as an artifact every epoch's end

        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        model_artifact.add_file('dalle.pt')
        run.log_artifact(model_artifact)

if distr_backend.is_root_worker():
    # Save the final trained weights locally, mirror the file to wandb, and
    # register it as a versioned artifact before closing the run.
    # Fix: the path literal had an f-prefix with no placeholders (flake8 F541).
    save_model('./dalle-final.pt')
    wandb.save('./dalle-final.pt')
    model_artifact = wandb.Artifact('trained-dalle',
                                    type='model',
                                    metadata=dict(model_config))
    model_artifact.add_file('dalle-final.pt')
    run.log_artifact(model_artifact)

    wandb.finish()
def test_model(
    model,
    attention = False
):
    """Rebuild inference encoder/decoder models from a trained seq2seq
    ``model`` and evaluate exact-match transliteration accuracy on
    ``dataBase.test``.

    Args:
        model: the trained Keras training-time model whose layers are
            re-wired into separate inference-time encoder/decoder models.
        attention: when True, use the attention-based wiring and the
            ``config_best_attention2`` wandb config; otherwise use
            ``config_best``.

    Returns:
        (accuracy, source_strings, reference_strings, predicted_strings).

    Fixes vs. the original:
    - the attention branch unpacked ``decode_sequence``'s single string
      return into two variables, raising ValueError at runtime;
    - the non-attention RNN/GRU state reshape hardcoded 256 instead of
      ``config.latentDim`` (the attention branch already used the config).
    """

    if attention == False:
        wandb.init(config=config_best,  project="CS6910-Assignment-3", entity="rahulsundar")
        config = wandb.config
        wandb.run.name = (
            "Inference_" 
            + str(config.cell_type)
            + dataBase.source_lang
            + str(config.numEncoders)
            + "_"
            + dataBase.target_lang
            + "_"
            + str(config.numDecoders)
            + "_"
            + config.optimiser
            + "_"
            + str(config.epochs)
            + "_"
            + str(config.dropout) 
            + "_"
            + str(config.batch_size)
            + "_"
            + str(config.latentDim)
        )
        wandb.run.save()

        if config.cell_type == "LSTM":
            encoder_inputs = model.input[0]

            # Keras auto-names layers "lstm", "lstm_1", ...; the last encoder
            # layer is "lstm_<numEncoders-1>" when stacked.
            if config.numEncoders == 1:
                encoder_outputs, state_h_enc, state_c_enc = model.get_layer(name = "lstm").output 
            else:           
                encoder_outputs, state_h_enc, state_c_enc = model.get_layer(name = "lstm_"+ str(config.numEncoders-1)).output

            encoder_states = [state_h_enc, state_c_enc]
            encoder_model = Model(encoder_inputs, encoder_states)

            decoder_inputs = model.input[1]
            decoder_state_input_h = Input(shape=(config.latentDim,), name="input_3")
            decoder_state_input_c = Input(shape=(config.latentDim,), name="input_4")
            decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
            decoder_lstm = model.layers[-3]
            decoder_outputs, state_h_dec, state_c_dec = decoder_lstm( decoder_inputs, initial_state=decoder_states_inputs )
            decoder_states = [state_h_dec, state_c_dec]
            # Two dense layers at the top (hidden projection + softmax).
            decoder_dense = model.layers[-2]
            decoder_outputs = decoder_dense(decoder_outputs)
            decoder_dense = model.layers[-1]
            decoder_outputs = decoder_dense(decoder_outputs)
            decoder_model = Model(
                [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
            )
        elif config.cell_type == "GRU" or config.cell_type == "RNN":
            encoder_inputs = model.input[0]
            if config.cell_type == "GRU":
                if config.numEncoders == 1:
                    encoder_outputs, state = model.get_layer(name = "gru").output
                else:
                    encoder_outputs, state = model.get_layer(name = "gru_"+ str(config.numEncoders-1)).output
            else:
                if config.numEncoders == 1:
                    encoder_outputs, state = model.get_layer(name = "simple_rnn").output
                else:
                    encoder_outputs, state = model.get_layer(name = "simple_rnn_"+ str(config.numEncoders-1)).output

            encoder_states = [state]

            encoder_model = Model(encoder_inputs, encoder_states)

            decoder_inputs = model.input[1]

            decoder_state = Input(shape=(config.latentDim,), name="input_3")
            decoder_states_inputs = [decoder_state]

            decoder_gru = model.layers[-3]
            (decoder_outputs, state,) = decoder_gru(decoder_inputs, initial_state=decoder_states_inputs)
            decoder_states = [state]
            decoder_dense = model.layers[-2]
            decoder_outputs = decoder_dense(decoder_outputs)
            decoder_dense = model.layers[-1]
            decoder_outputs = decoder_dense(decoder_outputs)
            decoder_model = Model(
                [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
            )

        def decode_sequence(input_seq):
            """Greedy character-by-character decoding of one input sequence."""
            # Encode the input as state vectors.
            states_value = encoder_model.predict(input_seq)

            # Generate empty target sequence of length 1, seeded with the
            # start character ("\n" doubles as start/stop marker here).
            target_seq = np.zeros((1, 1, len(dataBase.target_char2int)))
            target_seq[0, 0, dataBase.target_char2int["\n"]] = 1.0

            # Sampling loop for a batch of sequences (batch of size 1).
            stop_condition = False
            decoded_sentence = ""
            while not stop_condition:
                if config.cell_type == "LSTM":
                    output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
                elif config.cell_type == "RNN" or config.cell_type == "GRU":
                    # BUGFIX: use the configured latent dim instead of a
                    # hardcoded 256 (matches the attention branch).
                    states_value = states_value[0].reshape((1, config.latentDim))
                    output_tokens, h = decoder_model.predict([target_seq] + [states_value])

                # Sample the most likely token.
                sampled_token_index = np.argmax(output_tokens[0, -1, :])
                sampled_char = dataBase.target_int2char[sampled_token_index]
                decoded_sentence += sampled_char

                # Exit on stop character or max length.
                if sampled_char == "\n" or len(decoded_sentence) > 25:
                    stop_condition = True

                # Update the target sequence (of length 1).
                target_seq = np.zeros((1, 1, len(dataBase.target_char2int)))
                target_seq[0, 0, sampled_token_index] = 1.0

                # Update states.
                if config.cell_type == "LSTM":
                    states_value = [h, c]
                elif config.cell_type == "RNN" or config.cell_type == "GRU":
                    states_value = [h]
            return decoded_sentence

        acc = 0
        sourcelang = []
        predictions = []
        original = []
        for i, row in dataBase.test.iterrows():
            input_seq = dataBase.test_encoder_input[i : i + 1]
            decoded_sentence = decode_sequence(input_seq)
            og_tokens = [dataBase.target_char2int[x] for x in row["tgt"]]
            predicted_tokens = [dataBase.target_char2int[x] for x in decoded_sentence.rstrip("\n")]
            sourcelang.append(row['src'])
            original.append(row['tgt'])
            predictions.append(decoded_sentence)

            # Exact-match accuracy on token sequences.
            if og_tokens == predicted_tokens:
                acc += 1

            if i % 100 == 0:
                print(f"Finished {i} examples")
                print(f"Source: {row['src']}")
                print(f"Original: {row['tgt']}")
                print(f"Predicted: {decoded_sentence}")
                print(f"Accuracy: {acc / (i+1)}")
                print(og_tokens)
                print(predicted_tokens)

        print(f'Test Accuracy: {acc}')
        wandb.log({'test_accuracy': acc / len(dataBase.test)})
        wandb.finish()
        return acc / len(dataBase.test), sourcelang, original, predictions

    elif attention == True:
        wandb.init(config=config_best_attention2,  project="CS6910-Assignment-3", entity="rahulsundar")
        config = wandb.config
        wandb.run.name = (
            "Inference_WithAttn_" 
            + str(config.cell_type)
            + dataBase.source_lang
            + str(config.numEncoders)
            + "_"
            + dataBase.target_lang
            + "_"
            + str(config.numDecoders)
            + "_"
            + config.optimiser
            + "_"
            + str(config.epochs)
            + "_"
            + str(config.dropout) 
            + "_"
            + str(config.batch_size)
            + "_"
            + str(config.latentDim)
        )
        wandb.run.save()

        if config.cell_type == "LSTM":
            encoder_inputs = model.input[0]
            if config.numEncoders == 1:
                encoder_outputs, state_h_enc, state_c_enc = model.get_layer(name = "lstm").output 
            else:           
                encoder_outputs, state_h_enc, state_c_enc = model.get_layer(name = "lstm_"+ str(config.numEncoders-1)).output
            # First LSTM layer's full sequence output feeds the attention.
            encoder_first_outputs, _, _ = model.get_layer(name = "lstm").output
            encoder_states = [state_h_enc, state_c_enc]
            # NOTE(review): this encoder_model exposes only 2 outputs, but
            # decode_sequence below unpacks 3 values from its predict() —
            # the LSTM+attention path looks broken as written; the GRU/RNN
            # branch builds a 3-output encoder_model. Confirm and align.
            encoder_model = Model(encoder_inputs, encoder_states)

            decoder_inputs = model.input[1]
            decoder_state_input_h = Input(shape=(config.latentDim,), name="input_3")
            decoder_state_input_c = Input(shape=(config.latentDim,), name="input_4")
            decoder_hidden_state = Input(shape=(None,config["latentDim"]), name = "input_5")
            decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
            decoder_lstm = model.get_layer(name = "lstm_"+ str(config.numEncoders + config.numDecoders -1))
            decoder_outputs, state_h_dec, state_c_dec = decoder_lstm( decoder_inputs, initial_state=decoder_states_inputs )
            decoder_states = [state_h_dec, state_c_dec]

            attention_layer = model.get_layer(name = "attention_layer")
            attention_out, attention_states = attention_layer([encoder_first_outputs, decoder_outputs])

            decoder_concat_input = Concatenate(axis=-1, name='concat_layer')([decoder_outputs, attention_out])

            decoder_dense = model.layers[-2]
            decoder_time = TimeDistributed(decoder_dense)
            hidden_outputs = decoder_time(decoder_concat_input)
            decoder_dense = model.layers[-1]
            decoder_outputs = decoder_dense(hidden_outputs)

            decoder_model = Model(inputs = [decoder_inputs] + [decoder_hidden_state , decoder_states_inputs], outputs = [decoder_outputs] + decoder_states)

        elif config.cell_type == "GRU" or config.cell_type == "RNN":
            encoder_inputs = model.input[0]
            if config.cell_type == "GRU":
                if config.numEncoders == 1:
                    encoder_outputs, state = model.get_layer(name = "gru").output
                else:
                    encoder_outputs, state = model.get_layer(name = "gru_"+ str(config.numEncoders-1)).output
                encoder_first_outputs, _ = model.get_layer(name = "gru").output
            else:
                if config.numEncoders == 1:
                    encoder_outputs, state = model.get_layer(name = "simple_rnn").output
                else:
                    encoder_outputs, state = model.get_layer(name = "simple_rnn_"+ str(config.numEncoders-1)).output
                encoder_first_outputs, _ = model.get_layer(name = "simple_rnn").output
            encoder_states = [state]

            # 3 outputs: attention source sequence, final outputs, state.
            encoder_model = Model(encoder_inputs, outputs = [encoder_first_outputs, encoder_outputs] + encoder_states)

            decoder_inputs = model.input[1]

            decoder_state = Input(shape=(config.latentDim,), name="input_3")
            decoder_hidden_state = Input(shape=(None,config["latentDim"]), name = "input_4")
            decoder_states_inputs = [decoder_state]

            if config.cell_type == "GRU":
                decoder_gru = model.get_layer(name = "gru_"+ str(config.numEncoders + config.numDecoders -1))
                (decoder_outputs, state) = decoder_gru(decoder_inputs, initial_state=decoder_states_inputs)
                decoder_states = [state]
            else:
                decoder_gru = model.get_layer(name = "simple_rnn_"+ str(config.numEncoders + config.numDecoders -1))
                (decoder_outputs, state) = decoder_gru(decoder_inputs, initial_state=decoder_states_inputs)
                decoder_states = [state]

            attention_layer = AttentionLayer(name='attention_layer')
            attention_out, attention_states = attention_layer([decoder_hidden_state, decoder_outputs])

            decoder_concat_input = Concatenate(axis=-1, name='concat_layer')([decoder_outputs, attention_out])

            decoder_dense = model.layers[-2]
            decoder_time = TimeDistributed(decoder_dense)
            hidden_outputs = decoder_time(decoder_concat_input)
            decoder_dense = model.layers[-1]
            decoder_outputs = decoder_dense(hidden_outputs)

            decoder_model = Model(inputs = [decoder_inputs] + [decoder_hidden_state , decoder_states_inputs], outputs = [decoder_outputs] + decoder_states)

        def decode_sequence(input_seq):
            """Greedy decoding with attention over the first encoder layer's
            output sequence. Returns only the decoded string."""
            # Encode the input as state vectors.
            encoder_first_outputs, _, states_value = encoder_model.predict(input_seq)

            # Generate empty target sequence of length 1, seeded with the
            # start character.
            target_seq = np.zeros((1, 1, len(dataBase.target_char2int)))
            target_seq[0, 0, dataBase.target_char2int["\n"]] = 1.0

            # Sampling loop for a batch of sequences (batch of size 1).
            stop_condition = False
            decoded_sentence = ""
            while not stop_condition:
                if config.cell_type == "LSTM":
                    output_tokens, h, c = decoder_model.predict([target_seq, encoder_first_outputs] + states_value)
                elif config.cell_type == "RNN" or config.cell_type == "GRU":
                    states_value = states_value[0].reshape((1, config.latentDim))
                    output_tokens, h = decoder_model.predict([target_seq] + [encoder_first_outputs] + [states_value])

                # Sample the most likely token.
                sampled_token_index = np.argmax(output_tokens[0, -1, :])
                sampled_char = dataBase.target_int2char[sampled_token_index]
                decoded_sentence += sampled_char

                # Exit on stop character or max length.
                if sampled_char == "\n" or len(decoded_sentence) > 25:
                    stop_condition = True

                # Update the target sequence (of length 1).
                target_seq = np.zeros((1, 1, len(dataBase.target_char2int)))
                target_seq[0, 0, sampled_token_index] = 1.0

                # Update states.
                if config.cell_type == "LSTM":
                    states_value = [h, c]
                elif config.cell_type == "RNN" or config.cell_type == "GRU":
                    states_value = [h]
            return decoded_sentence

        acc = 0
        sourcelang = []
        predictions = []
        original = []
        for i, row in dataBase.test.iterrows():
            input_seq = dataBase.test_encoder_input[i : i + 1]
            # BUGFIX: decode_sequence returns a single string; the original
            # unpacked it into (decoded_sentence, attention_weights), which
            # raises ValueError (or silently splits a 2-char prediction).
            decoded_sentence = decode_sequence(input_seq)
            og_tokens = [dataBase.target_char2int[x] for x in row["tgt"]]
            predicted_tokens = [dataBase.target_char2int[x] for x in decoded_sentence.rstrip("\n")]
            sourcelang.append(row['src'])
            original.append(row['tgt'])
            predictions.append(decoded_sentence)
            if og_tokens == predicted_tokens:
                acc += 1

            if i % 100 == 0:
                print(f"Finished {i} examples")
                print(f"Source: {row['src']}")
                print(f"Original: {row['tgt']}")
                print(f"Predicted: {decoded_sentence}")
                print(f"Accuracy: {acc / (i+1)}")
                print(og_tokens)
                print(predicted_tokens)

        print(f'Test Accuracy: {acc}')
        wandb.log({'test_accuracy': acc / len(dataBase.test)})
        wandb.finish()
        return acc / len(dataBase.test) , sourcelang, original, predictions