Example #1
class FNN: 
    def __init__(self):
        ## Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    def set_data(self, features, targets, D, denom_sq):
        self.features_np = features
        self.targets_np = targets
        self.D_np = D
        self.inv_denom_sq = denom_sq**-1
    
    def train(self, config):
        ## Internal config
        self.config = {}
        self.config['num_epochs']       = 5000
        self.config['n_hidden']         = 2
        self.config['hidden_size']      = 40
        self.config['batch_size']       = 10
        self.config['lr']               = 1e-2
        self.config['regularization']   = 1e-10
        # Overwrite internal config values given in the external config
        if config:
            for key in config.keys():
                self.config[key] = config[key]
        
        # Assume we're using ray.tune at first
        self.tuning = True
        
        ## Model
        self.config['input_size'] = self.features_np['train'].shape[1]
        self.config['output_size'] = self.targets_np['train'].shape[1]
        self.model = Model(self.config).to(self.device)
        
        ## Data loaders
        self.batch_size = self.config['batch_size']
        self.train_loader = data_loader.create_loader(
            self.features_np['train'],
            self.targets_np['train'],
            self.batch_size,
            True)
        self.validate_loader  = data_loader.create_loader(
            self.features_np['validate'],
            self.targets_np['validate'],
            self.features_np['validate'].shape[0], # use all validation samples in one batch
            False)                                 # don't shuffle
        
        ## Hyperparameters
        self.num_epochs = self.config['num_epochs']
        self.learning_rate = self.config['lr']
        
        ## Loss and optimizer
        self.criterion = self.eps_reg_sq
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, eps=1e-8, weight_decay=self.config['regularization'])
        lambdaLR = lambda epoch: 1 / (1 + 0.005*epoch)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambdaLR)
        
        self.train_start()
    
    def train_start(self):
        ## Train
        early_stop = False
        self.D = torch.from_numpy(self.D_np).float().to(self.device)
        
        for epoch in range(self.num_epochs):
            for i, (features, targets) in enumerate(self.train_loader):
                self.model.train()
                self.optimizer.zero_grad()
                
                # Move tensors to the configured device
                features = features.to(self.device)
                targets  =  targets.to(self.device)
                
                # Forward pass
                outputs = self.model(features)
                loss = self.criterion(outputs, targets) ** 0.5
                if torch.isnan(loss):
                    print('Something went nan, stopping')
                    early_stop = True
                    break # break out of this batch

                # Backward and optimize
                loss.backward()
                self.optimizer.step()
            
            if early_stop:
                break # break out of this epoch
                
            self.scheduler.step()
                
            if epoch%10==0 or epoch==self.num_epochs-1:
                validate_loss  = self.get_loss(self.validate_loader)
                train_loss = self.get_loss(self.train_loader)
                print('eps_reg: Epoch [{}/{}], LR: {:.2e}, Train loss: {:.2e}, Validate loss: {:.2e}'
                    .format(epoch+1, self.num_epochs, self.scheduler.get_lr()[0], train_loss.item()**0.5, validate_loss.item()**0.5))
                
                if self.tuning:
                    try:
                        tune.track.log(mean_loss = validate_loss.item(), episodes_this_iter = 10)
                    except:
                        self.tuning = False
        return self

    def eps_reg_sq(self, outputs, targets):
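        # Weighted squared error ||D * (targets - outputs)||^2, normalized by denom_sq and the batch size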
        return torch.sum((self.D*(targets - outputs)) ** 2) * self.inv_denom_sq / targets.shape[0]
        
    def get_loss(self, loader):
        with torch.no_grad():
            self.model.eval()
            loss = 0.0
            for features, targets in loader:
                features = features.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(features)
                loss += self.criterion(outputs, targets)
            return loss/len(loader)

    def evaluate(self, features):
        with torch.no_grad():
            self.model.eval()
            # move inputs to the configured device and bring the output back to the CPU
            # so evaluation also works when the model lives on the GPU
            output = self.model(torch.tensor(features).float().to(self.device))
            u_rb = output.cpu().numpy()
            return u_rb
    
    def save(self, model_dir, component):
        try:
            path_config     = os.path.join(tune.track.trial_dir(),'config')
            path_state_dict = os.path.join(tune.track.trial_dir(),'state_dict')
        except:
            # not tuning
            path_config     = os.path.join(model_dir, 'FNN', component,'config')
            path_state_dict = os.path.join(model_dir, 'FNN', component,'state_dict')
        with open(path_config, 'wb+') as f:
            pickle.dump(self.config, f)
        
        torch.save(self.model.state_dict(), path_state_dict)
    
    def load(self, model_dir, component):
        '''
        Finds and loads the best model from the ray.tune analysis results,
        falling back to the plain save path when no tuning records exist.
        '''
        try:
            path_analysis = os.path.join(model_dir,'FNN',component)
            analysis = tune.Analysis(path_analysis)
            df_temp = analysis.dataframe()
            idx = df_temp['mean_loss'].idxmin()
            logdir = df_temp.loc[idx]['logdir']
            path_config     = os.path.join(logdir,'config')
            path_state_dict = os.path.join(logdir,'state_dict')
        except:
            # no tuning records
            path_config     = os.path.join(model_dir, 'FNN', component,'config')
            path_state_dict = os.path.join(model_dir, 'FNN', component,'state_dict')
            
        
        with open(path_config, 'rb') as f:
            config = pickle.load(f)
            self.model = Model(config).to(self.device)
        
        state_dict = torch.load(path_state_dict,
                                map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
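The scheduler in this example rescales Adam's base learning rate by 1/(1 + 0.005*epoch) each time scheduler.step() is called. A minimal, self-contained sketch of that behaviour (the one-parameter model below is a placeholder, not part of the example above):

import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-2)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.005 * epoch))

for epoch in range(3):
    optimizer.step()                       # parameter update for this epoch
    scheduler.step()                       # lr becomes 1e-2 / (1 + 0.005 * (epoch + 1))
    print(epoch, scheduler.get_last_lr())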
Example #2
def train_domain_classifier(
        model: torch.nn.Module,
        train_dl: DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        class_weights: List,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''):
    #best_loss = float('inf')
    best_acc = 0.0
    patience_counter = 0

    epoch_counter = 0
    total = len(train_dl)  # number of batches per epoch
    loss_fn = torch.nn.CrossEntropyLoss(
        weight=torch.FloatTensor(class_weights).to(device))

    # Main loop
    while epoch_counter < n_epochs:
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            masks = batch[1]
            labels = batch[2]
            # Testing with random domains to see if any effect
            #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
            domains = batch[3]

            logits = model(input_ids, attention_mask=masks)[0]
            loss = loss_fn(logits, domains)
            loss = loss / gradient_accumulation

            #if i % gradient_accumulation == 0:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation acc: {acc}")

        # Saving the best model and early stopping
        #if val_loss < best_loss:
        if acc > best_acc:
            best_model = model.state_dict()
            best_acc = acc
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_domainclassifier_{domain_name}.pth'
            )
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
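In this example the loss is divided by gradient_accumulation, but the `if i % gradient_accumulation == 0` guard is commented out, so the optimizer still steps on every batch. A self-contained sketch of the usual accumulation pattern (the tiny linear model and random data are placeholders, not the author's setup):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
gradient_accumulation = 4

batches = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(8)]
for i, (x, y) in enumerate(batches):
    loss = loss_fn(model(x), y) / gradient_accumulation  # scale so accumulated grads match one large batch
    loss.backward()
    if (i + 1) % gradient_accumulation == 0:             # update only every N mini-batches
        optimizer.step()
        optimizer.zero_grad()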
Example #3
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, interp, criterion, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_t = AverageMeter('Loss (t)', ':3.2f')
    losses_entropy_t = AverageMeter('Entropy (t)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')

    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_t, losses_entropy_t,
         accuracies_s, accuracies_t, iou_s, iou_t],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_cls_t = criterion(pred_t, label_t)
        loss_entropy_t = robust_entropy(y_t, args.ita)
        (args.entropy_weight * loss_entropy_t).backward()
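        # loss_cls_s.backward() above and this weighted entropy term both accumulate gradients
        # before the single optimizer.step() below; loss_cls_t is only logged, never backpropagated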

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_t.update(loss_cls_t.item(), x_s.size(0))
        losses_entropy_t.update(loss_entropy_t.item(), x_s.size(0))

        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))
Example #4
    params = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in params if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]
    optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.lr, 
                          bias_correction=False)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1/(1 + 0.05*epoch))

    ## DATA
    train_loader, val_loader, test_loader = get_data(args)

    for epoch in range(args.epochs):
        train(model, optimizer, train_loader, epoch, args)
        if val_loader is not None:
            loss, f1 = evaluate(model, val_loader, args)
            print('val_loss: {:.5f}, val_f1: {:.5f}'.format(loss, f1))
            #print('val_loss: {:.5f}, classification: \n{}'.format(loss, f1))
        scheduler.step()

    if test_loader is not None:
        loss, f1 = evaluate(model, test_loader, args)
        print('test_loss: {:.5f}, test_f1: {:.5f}'.format(loss, f1))
        #print('test_loss: {:.5f}, classification: \n{}'.format(loss, f1))
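The snippet above relies on a `no_decay` list defined elsewhere in the original file. A common convention (an assumption here, shown with a stand-in model and torch.optim.AdamW instead of apex's FusedAdam) is to exclude biases and LayerNorm weights from weight decay:

import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.LayerNorm = torch.nn.LayerNorm(4)

model = Tiny()
no_decay = ['bias', 'LayerNorm.weight']
params = list(model.named_parameters())
optimizer_grouped_parameters = [
    {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5)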
Example #5
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (self.world_rank *
                                         self.config.NUM_PROCESSES)
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(
            self.config,
            get_env_class(self.config.ENV_NAME),
            workers_ignore_signals=True,
        )

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        obs_space = self.envs.observation_spaces[0]
        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            obs_space = SpaceDict({
                "visual_features":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            })
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [value_loss, action_loss, count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[2].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {k: l
                         for l, k in zip(losses, ["value", "policy"])},
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        count_checkpoints += 1

            self.envs.close()
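The LambdaLR above multiplies the base learning rate by linear_decay(update, NUM_UPDATES), a helper defined elsewhere in habitat_baselines. A typical definition (an assumption, included only to make the schedule concrete) ramps the multiplier linearly from 1 to 0 over the planned number of updates:

def linear_decay(epoch: int, total_num_updates: int) -> float:
    # multiplier applied to the base learning rate by LambdaLR
    return 1 - (epoch / float(total_num_updates))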
Example #6
def main(args):
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size,
                            ratio=args.resize_ratio,
                            scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(),
                                       netG_T2S.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(),
                                       netD_T.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs
                                                ) / float(args.epochs_decay)
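    # keeps the LR factor at 1.0 for the first args.epochs epochs, then decays it
    # linearly to zero over the following args.epochs_decay epochs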
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'test':
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # define visualization function
    tensor_to_image = Compose(
        [Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
         ToPILImage()])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(
            logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        print(lr_scheduler_G.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S,
              netD_T, criterion_gan, criterion_cycle, criterion_identity,
              optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch,
              visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))

    if args.translated_root is not None:
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    logger.close()
Example #7
    def train(self) -> None:
        r"""Main method for training DD/PPO.

        Returns:
            None
        """

        self._init_train()

        count_checkpoints = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: 1 - self.percent_done(),
        )

        resume_state = load_resume_state(self.config)
        if resume_state is not None:
            self.agent.load_state_dict(resume_state["state_dict"])
            self.agent.optimizer.load_state_dict(resume_state["optim_state"])
            lr_scheduler.load_state_dict(resume_state["lr_sched_state"])

            requeue_stats = resume_state["requeue_stats"]
            self.env_time = requeue_stats["env_time"]
            self.pth_time = requeue_stats["pth_time"]
            self.num_steps_done = requeue_stats["num_steps_done"]
            self.num_updates_done = requeue_stats["num_updates_done"]
            self._last_checkpoint_percent = requeue_stats[
                "_last_checkpoint_percent"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            prev_time = requeue_stats["prev_time"]

            self.running_episode_stats = requeue_stats["running_episode_stats"]
            self.window_episode_stats.update(
                requeue_stats["window_episode_stats"])

        ppo_cfg = self.config.RL.PPO

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if rank0_only() else contextlib.suppress()) as writer:
            while not self.is_done():
                profiling_wrapper.on_start_step()
                profiling_wrapper.range_push("train update")

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * (
                        1 - self.percent_done())

                if rank0_only() and self._should_save_resume_state():
                    requeue_stats = dict(
                        env_time=self.env_time,
                        pth_time=self.pth_time,
                        count_checkpoints=count_checkpoints,
                        num_steps_done=self.num_steps_done,
                        num_updates_done=self.num_updates_done,
                        _last_checkpoint_percent=self._last_checkpoint_percent,
                        prev_time=(time.time() - self.t_start) + prev_time,
                        running_episode_stats=self.running_episode_stats,
                        window_episode_stats=dict(self.window_episode_stats),
                    )

                    save_resume_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ),
                        self.config,
                    )

                if EXIT.is_set():
                    profiling_wrapper.range_pop()  # train update

                    self.envs.close()

                    requeue_job()

                    return

                self.agent.eval()
                count_steps_delta = 0
                profiling_wrapper.range_push("rollouts loop")

                profiling_wrapper.range_push("_collect_rollout_step")
                for buffer_index in range(self._nbuffers):
                    self._compute_actions_and_step_envs(buffer_index)

                for step in range(ppo_cfg.num_steps):
                    is_last_step = (self.should_end_early(step + 1)
                                    or (step + 1) == ppo_cfg.num_steps)

                    for buffer_index in range(self._nbuffers):
                        count_steps_delta += self._collect_environment_result(
                            buffer_index)

                        if (buffer_index + 1) == self._nbuffers:
                            profiling_wrapper.range_pop(
                            )  # _collect_rollout_step

                        if not is_last_step:
                            if (buffer_index + 1) == self._nbuffers:
                                profiling_wrapper.range_push(
                                    "_collect_rollout_step")

                            self._compute_actions_and_step_envs(buffer_index)

                    if is_last_step:
                        break

                profiling_wrapper.range_pop()  # rollouts loop

                if self._is_distributed:
                    self.num_rollouts_done_store.add("num_done", 1)

                (
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent()

                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()  # type: ignore

                self.num_updates_done += 1
                losses = self._coalesce_post_step(
                    dict(value_loss=value_loss, action_loss=action_loss),
                    count_steps_delta,
                )

                self._training_log(writer, losses, prev_time)

                # checkpoint model
                if rank0_only() and self.should_checkpoint():
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(
                            step=self.num_steps_done,
                            wall_time=(time.time() - self.t_start) + prev_time,
                        ),
                    )
                    count_checkpoints += 1

                profiling_wrapper.range_pop()  # train update

            self.envs.close()
Example #8
    def train(self) -> None:
        r"""Main method for DD-PPO SLAM.

        Returns:
            None
        """

        #####################################################################
        ## init distrib and configuration #####################################################################
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        # self.local_rank = 1
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore(
            "rollout_tracker", tcp_store
        )
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank() # server number
        self.world_size = distrib.get_world_size() 

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank # gpu number in one server
        self.config.SIMULATOR_GPU_ID = self.local_rank
        print("********************* TORCH_GPU_ID: ", self.config.TORCH_GPU_ID)
        print("********************* SIMULATOR_GPU_ID: ", self.config.SIMULATOR_GPU_ID)

        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (
            self.world_rank * self.config.NUM_PROCESSES
        )
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")


        #####################################################################
        ## build distrib NavSLAMRLEnv environment
        #####################################################################
        print("#############################################################")
        print("## build distrib NavSLAMRLEnv environment")
        print("#############################################################")
        self.envs = construct_envs(
            self.config, get_env_class(self.config.ENV_NAME)
        )
        observations = self.envs.reset()
        print("*************************** observations len:", len(observations))

        # semantic process
        for i in range(len(observations)):
            observations[i]["semantic"] = observations[i]["semantic"].astype(np.int32)
            se = list(set(observations[i]["semantic"].ravel()))
            print(se)
        # print("*************************** observations type:", observations)
        # print("*************************** observations type:", observations[0]["map_sum"].shape) # 480*480*23
        # print("*************************** observations curr_pose:", observations[0]["curr_pose"]) # []

        batch = batch_obs(observations, device=self.device)
        print("*************************** batch len:", len(batch))
        # print("*************************** batch:", batch)

        # print("************************************* current_episodes:", (self.envs.current_episodes()))

        #####################################################################
        ## init actor_critic agent
        #####################################################################  
        print("#############################################################")
        print("## init actor_critic agent")
        print("#############################################################")
        self.map_w = observations[0]["map_sum"].shape[0]
        self.map_h = observations[0]["map_sum"].shape[1]
        # print("map_: ", observations[0]["curr_pose"].shape)


        ppo_cfg = self.config.RL.PPO
        if (
            not os.path.isdir(self.config.CHECKPOINT_FOLDER)
            and self.world_rank == 0
        ):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(observations, ppo_cfg)

        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info(
                "agent number of trainable parameters: {}".format(
                    sum(
                        param.numel()
                        for param in self.agent.parameters()
                        if param.requires_grad
                    )
                )
            )

        #####################################################################
        ## init Global Rollout Storage
        #####################################################################  
        print("#############################################################")
        print("## init Global Rollout Storage")
        print("#############################################################") 
        self.num_each_global_step = self.config.RL.SLAMDDPPO.num_each_global_step
        rollouts = GlobalRolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.obs_space,
            self.g_action_space,
        )
        rollouts.to(self.device)

        print('rollouts type:', type(rollouts))
        print('--------------------------')
        # for k in rollouts.keys():
        # print("rollouts: {0}".format(rollouts.observations))

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        with torch.no_grad():
            step_observation = {
                k: v[rollouts.step] for k, v in rollouts.observations.items()
            }
    
            _, actions, _, = self.actor_critic.act(
                step_observation,
                rollouts.prev_g_actions[0],
                rollouts.masks[0],
            )

        self.global_goals = [[int(action[0].item() * self.map_w), 
                            int(action[1].item() * self.map_h)]
                            for action in actions]

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(
            self.envs.num_envs, 1, device=self.device
        )
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )

        print("*************************** current_episode_reward:", current_episode_reward)
        print("*************************** running_episode_stats:", running_episode_stats)
        # print("*************************** window_episode_stats:", window_episode_stats)


        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        # interrupted_state = load_interrupted_state("/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth")
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"]
            )
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        deif = {}
        with (
            TensorboardWriter(
                self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
            )
            if self.world_rank == 0
            else contextlib.suppress()
        ) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES
                    )
                # print("************************************* current_episodes:", type(self.envs.count_episodes()))
                
                # print(EXIT.is_set())
                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ),
                            "/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth"
                        )
                    print("********************EXIT*********************")

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_global_rollout_step(
                        rollouts, current_episode_reward, running_episode_stats
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # print("************************************* current_episodes:")

                    for i in range(len(self.envs.current_episodes())):
                        # print(" ", self.envs.current_episodes()[i].episode_id," ", self.envs.current_episodes()[i].scene_id," ", self.envs.current_episodes()[i].object_category)
                        if self.envs.current_episodes()[i].scene_id not in deif:
                            deif[self.envs.current_episodes()[i].scene_id]=[int(self.envs.current_episodes()[i].episode_id)]
                        else:
                            deif[self.envs.current_episodes()[i].scene_id].append(int(self.envs.current_episodes()[i].episode_id))


                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (
                        step
                        >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                    ) and int(num_rollouts_done_store.get("num_done")) > (
                        self.config.RL.DDPPO.sync_frac * self.world_size
                    ):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0
                )
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [value_loss, action_loss, count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[2].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    deltas = {
                        k: (
                            (v[-1] - v[0]).sum().item()
                            if len(v) > 1
                            else v[0].sum().item()
                        )
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {k: l for l, k in zip(losses, ["value", "policy"])},
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info(
                            "update: {}\tfps: {:.3f}\t".format(
                                update,
                                count_steps
                                / ((time.time() - t_start) + prev_time),
                            )
                        )

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(
                                update, env_time, pth_time, count_steps
                            )
                        )
                        logger.info(
                            "Average window size: {}  {}".format(
                                len(window_episode_stats["count"]),
                                "  ".join(
                                    "{}: {:.3f}".format(k, v / deltas["count"])
                                    for k, v in deltas.items()
                                    if k != "count"
                                ),
                            )
                        )

                        # for k in deif:
                        #     deif[k] = list(set(deif[k]))
                        #     deif[k].sort()
                        #     print("deif: k", k, " : ", deif[k])

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        print('=' * 20 + 'Save Model' + '=' * 20)
                        logger.info(
                            "Save Model : {}".format(count_checkpoints)
                        )
                        count_checkpoints += 1

            self.envs.close()
Example #9
class PPOAgent(BaseAgent):
    actor: nn.Module
    critic: nn.Module
    same_body: float = False

    def __post_init__(self):
        move_to([self.actor, self.critic], device=cfg.alg.device)
        if cfg.alg.vf_loss_type == 'mse':
            self.val_loss_criterion = nn.MSELoss().to(cfg.alg.device)
        elif cfg.alg.vf_loss_type == 'smoothl1':
            self.val_loss_criterion = nn.SmoothL1Loss().to(cfg.alg.device)
        else:
            raise TypeError(
                f'Unknown value loss type: {cfg.alg.vf_loss_type}!')
        all_params = list(self.actor.parameters()) + list(
            self.critic.parameters())
        # keep unique elements only. The following code works for python >=3.7
        # for earlier version of python, u need to use OrderedDict
        self.all_params = dict.fromkeys(all_params).keys()
        if (cfg.alg.linear_decay_lr or cfg.alg.linear_decay_clip_range) and \
                cfg.alg.max_steps > cfg.alg.max_decay_steps:
            logger.warning(
                'max_steps should not be greater than max_decay_steps.')
            cfg.alg.max_decay_steps = int(cfg.alg.max_steps * 1.5)
            logger.warning(
                f'Resetting max_decay_steps to {cfg.alg.max_decay_steps}!')
        total_epochs = int(
            np.ceil(cfg.alg.max_decay_steps /
                    (cfg.alg.num_envs * cfg.alg.episode_steps)))
        if cfg.alg.linear_decay_clip_range:
            self.clip_range_decay_rate = cfg.alg.clip_range / float(
                total_epochs)

        p_lr_lambda = partial(linear_decay_percent, total_epochs=total_epochs)
        optim_args = dict(lr=cfg.alg.policy_lr,
                          weight_decay=cfg.alg.weight_decay)
        if not cfg.alg.sgd:
            optim_args['amsgrad'] = cfg.alg.use_amsgrad
            optim_func = optim.Adam
        else:
            optim_args['nesterov'] = True if cfg.alg.momentum > 0 else False
            optim_args['momentum'] = cfg.alg.momentum
            optim_func = optim.SGD
        if self.same_body:
            optim_args['params'] = self.all_params
        else:
            optim_args['params'] = [{
                'params': self.actor.parameters(),
                'lr': cfg.alg.policy_lr
            }, {
                'params': self.critic.parameters(),
                'lr': cfg.alg.value_lr
            }]

        self.optimizer = optim_func(**optim_args)
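
        # LambdaLR accepts either a single lr_lambda or a list with one lambda per
        # optimizer param group; separate actor/critic groups get separate schedules below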

        if self.same_body:
            self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
                                         lr_lambda=[p_lr_lambda])
        else:
            v_lr_lambda = partial(linear_decay_percent,
                                  total_epochs=total_epochs)
            self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
                                         lr_lambda=[p_lr_lambda, v_lr_lambda])

    @torch.no_grad()
    def get_action(self, ob, sample=True, *args, **kwargs):
        self.eval_mode()
        if type(ob) is dict:
            t_ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            t_ob = torch_float(ob, device=cfg.alg.device)

        act_dist, val = self.get_act_val(t_ob)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        return torch_to_np(action), action_info

    def get_act_val(self, ob, *args, **kwargs):
        if type(ob) is dict:
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)

        act_dist, body_out = self.actor(ob)
        if self.same_body:
            val, body_out = self.critic(body_x=body_out)
        else:
            val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return act_dist, val

    @torch.no_grad()
    def get_val(self, ob, *args, **kwargs):
        self.eval_mode()

        if type(ob) is dict:
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)

        val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return val

    def optimize(self, data, *args, **kwargs):
        processed_data = self.optim_preprocess(data)
        processed_data['entropy'] = torch.mean(processed_data['entropy'])
        loss_res = self.cal_loss(**processed_data)
        loss, pg_loss, vf_loss, ratio = loss_res
        self.optimizer.zero_grad()
        loss.backward()

        grad_norm = clip_grad(self.all_params, cfg.alg.max_grad_norm)
        self.optimizer.step()
        with torch.no_grad():
            approx_kl = 0.5 * torch.mean(
                torch.pow(
                    processed_data['old_log_prob'] -
                    processed_data['log_prob'], 2))
            clip_frac = np.mean(
                np.abs(torch_to_np(ratio) - 1.0) > cfg.alg.clip_range)
        optim_info = dict(pg_loss=pg_loss.item(),
                          vf_loss=vf_loss.item(),
                          total_loss=loss.item(),
                          entropy=processed_data['entropy'].item(),
                          approx_kl=approx_kl.item(),
                          clip_frac=clip_frac)
        optim_info['grad_norm'] = grad_norm
        return optim_info

    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']

        act_dist, val = self.get_act_val({"ob": ob, "state": state})
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        if not all([x.ndim == 1 for x in [val, entropy, log_prob]]):
            raise ValueError('val, entropy, log_prob should be 1-dim!')
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data

    def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv,
                 entropy):
        vf_loss = self.cal_val_loss(val=val, old_val=old_val, ret=ret)
        ratio = torch.exp(log_prob - old_log_prob)
        surr1 = adv * ratio
        surr2 = adv * torch.clamp(ratio, 1 - cfg.alg.clip_range,
                                  1 + cfg.alg.clip_range)
        pg_loss = -torch.mean(torch.min(surr1, surr2))

        loss = pg_loss - entropy * cfg.alg.ent_coef + \
               vf_loss * cfg.alg.vf_coef
        return loss, pg_loss, vf_loss, ratio

    def cal_val_loss(self, val, old_val, ret):
        if cfg.alg.clip_vf_loss:
            clipped_val = old_val + torch.clamp(
                val - old_val, -cfg.alg.clip_range, cfg.alg.clip_range)
            vf_loss1 = torch.pow(val - ret, 2)
            vf_loss2 = torch.pow(clipped_val - ret, 2)
            vf_loss = 0.5 * torch.mean(torch.max(vf_loss1, vf_loss2))
        else:
            # val = torch.squeeze(val)
            vf_loss = 0.5 * self.val_loss_criterion(val, ret)
        return vf_loss

    def train_mode(self):
        self.actor.train()
        self.critic.train()

    def eval_mode(self):
        self.actor.eval()
        self.critic.eval()

    def decay_lr(self):
        self.lr_scheduler.step()

    def get_lr(self):
        cur_lr = self.lr_scheduler.get_last_lr()
        lrs = {'policy_lr': cur_lr[0]}
        if len(cur_lr) > 1:
            lrs['value_lr'] = cur_lr[1]
        return lrs

    def decay_clip_range(self):
        cfg.alg.clip_range -= self.clip_range_decay_rate

    def save_model(self, is_best=False, step=None):
        self.save_env(cfg.alg.model_dir)
        data_to_save = {
            'step': step,
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'optim_state_dict': self.optimizer.state_dict(),
            'lr_scheduler_state_dict': self.lr_scheduler.state_dict()
        }

        if cfg.alg.linear_decay_clip_range:
            data_to_save['clip_range'] = cfg.alg.clip_range
            data_to_save['clip_range_decay_rate'] = self.clip_range_decay_rate
        save_model(data_to_save, cfg.alg, is_best=is_best, step=step)

    def load_model(self, step=None, pretrain_model=None):
        self.load_env(cfg.alg.model_dir)
        ckpt_data = load_ckpt_data(cfg.alg,
                                   step=step,
                                   pretrain_model=pretrain_model)
        load_state_dict(self.actor, ckpt_data['actor_state_dict'])
        load_state_dict(self.critic, ckpt_data['critic_state_dict'])
        if pretrain_model is not None:
            return
        self.optimizer.load_state_dict(ckpt_data['optim_state_dict'])
        self.lr_scheduler.load_state_dict(ckpt_data['lr_scheduler_state_dict'])
        if cfg.alg.linear_decay_clip_range:
            self.clip_range_decay_rate = ckpt_data['clip_range_decay_rate']
            cfg.alg.clip_range = ckpt_data['clip_range']
        return ckpt_data['step']

    def print_param_grad_status(self):
        logger.info('Requires Grad?')
        logger.info('================== Actor ================== ')
        for name, param in self.actor.named_parameters():
            print(f'{name}: {param.requires_grad}')
        logger.info('================== Critic ================== ')
        for name, param in self.critic.named_parameters():
            print(f'{name}: {param.requires_grad}')
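
The agent above passes LambdaLR a list of lambdas, one per optimizer parameter group, so the policy and value learning rates decay independently. Below is a minimal standalone sketch of that pattern; linear_decay_percent is not defined in the example, so the version here is an assumed linear ramp to zero.

import torch
import torch.nn as nn
from functools import partial
from torch.optim.lr_scheduler import LambdaLR

def linear_decay_percent(epoch, total_epochs):
    # Assumed form: decay linearly from 1.0 at epoch 0 to 0.0 at total_epochs.
    return 1.0 - epoch / float(total_epochs)

actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
total_epochs = 100
optimizer = torch.optim.Adam([
    {'params': actor.parameters(), 'lr': 3e-4},   # policy group
    {'params': critic.parameters(), 'lr': 1e-3},  # value group
])
p_lr_lambda = partial(linear_decay_percent, total_epochs=total_epochs)
v_lr_lambda = partial(linear_decay_percent, total_epochs=total_epochs)
# One lr_lambda per parameter group; each group keeps its own base lr.
scheduler = LambdaLR(optimizer, lr_lambda=[p_lr_lambda, v_lr_lambda])

for epoch in range(total_epochs):
    # ... one training epoch would run here ...
    scheduler.step()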
Example #10
0
class Trainer(object):
    def __init__(self, args):
        self.name = args.name
        self.max_epoch = args.max_epoch
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.log = args.log
        self.out_every = args.out_every
        self.pos_w = args.pos_w
        self.LAMBDA = args.LAMBDA
        if args.cuda_dev:
            torch.cuda.set_device(args.cuda_dev[0])
            self.cuda_dev = f'cuda:{args.cuda_dev[0]}'
            self.device = 'cuda'
        else:
            self.cuda_dev = None
            self.device = 'cpu'
        print(f'Using {self.device}')
        self.z_dim = args.z_dim
        self.batch_size = args.batch_size
        self.start_save = args.start_save
        self.start_epoch = args.start_epoch
        self.ckpt_dir = os.path.join(args.ckpt_dir, self.name)
        if args.data_type == 'simATAC':
            self.dataset = SimATAC(args.setting,
                                   args.signal,
                                   args.frags,
                                   args.bin_size,
                                   conv=args.conv)
        elif args.data_type == 'atlas':
            self.dataset = MouseAtlas(cutoff=CUT_OFF)
        elif args.data_type == 'pbmc':
            self.dataset = PBMC()
        elif args.data_type == 'mergeSim':
            if args.num:
                self.dataset = MergeSim(args.num)
            else:
                self.dataset = MergeSim()
        else:
            raise Exception(f'Dataset {args.data_type} does not exist!')
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=3 * len(args.cuda_dev),
                                     pin_memory=True,
                                     drop_last=True)
        input_dim = self.dataset.padto
        if args.model_type == 'inv':
            if args.sample_batch:
                self.de_batch = True
                self.vae = VAE2(input_dim, args.z_dim, batch=True)
            else:
                self.de_batch = False
                self.vae = VAE2(input_dim, args.z_dim)
            self.vaeI = VAEInv(self.vae)
            self.model = nn.DataParallel(self.vaeI, device_ids=args.cuda_dev)
        else:
            raise Exception(f'Model type {args.model_type} does not exist!')
        self.model_type = args.model_type
        if args.load_ckpt:
            self.load_ckpt(args.load_ckpt)
            # if os.path.isfile(args.load_ckpt):
            #     print('Loading ' + args.load_ckpt)
            #     if self.cuda_dev:
            #         self.model.module.load_state_dict(torch.load(args.load_ckpt, map_location=self.cuda_dev))
            #     else:
            #         self.model.module.load_state_dict(torch.load(args.load_ckpt, map_location='cpu'))
            #     print('Finished Loading ckpt...')
            # else:
            #     raise Exception(args.load_ckpt + "\nckpt does not exist!")
        self.model.to(self.device)
        self.optim = optim.Adam(self.model.parameters(),
                                lr=self.lr,
                                weight_decay=self.weight_decay)
        self.cycle = CYCLE * self.dataset.__len__() // self.batch_size // len(
            args.cuda_dev)
        lr_lmd = lambda epoch: 0.995**epoch
        self.le_scdlr = LambdaLR(self.optim, lr_lambda=lr_lmd)
        self.le_scdlr.last_epoch = self.start_epoch - 1

    def load_ckpt(self, ckpt_pth):
        if os.path.isfile(ckpt_pth):
            print('Loading ' + ckpt_pth)
            if self.cuda_dev:
                self.model.module.load_state_dict(
                    torch.load(ckpt_pth, map_location=self.cuda_dev))
            else:
                self.model.module.load_state_dict(
                    torch.load(ckpt_pth, map_location='cpu'))
            print('Finished Loading ckpt...')
        else:
            raise Exception(ckpt_pth + "\nckpt does not exist!")

    def warm_up(self):
        if not os.path.exists(self.ckpt_dir):
            print(f'Making dir {self.ckpt_dir}')
            os.makedirs(self.ckpt_dir)
        self.model.train()
        self.pbar = tqdm(total=WARM_UP)
        total_iter = 0
        for step in range(WARM_UP):
            for x, s, l in self.dataloader:
                l = l.unsqueeze(1).float().to(self.device).log()
                l = (l - self.dataset.d_mean) / self.dataset.d_std
                total_iter += 1
                x = x.float().to(self.device)
                if self.model_type == 'adv':
                    _, _, _, rec, _ = self.model(x, l)
                elif self.model_type == 'inv':
                    if self.de_batch:
                        s = s.unsqueeze(1).float().to(self.device)
                        _, _, _, rec = self.model(x, l, b=s)
                    else:
                        _, _, _, rec = self.model(x, l)
                else:
                    _, _, _, rec = self.model(x)
                pos_weight = torch.Tensor([self.pos_w]).to(self.device)
                bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                # rec_loss = focal(rec.view(-1), x.view(-1).long())
                rec_loss = bce(rec, x)
                self.optim.zero_grad()
                rec_loss.backward()
                self.optim.step()
                # if total_iter%50 == 0:
                #     self.pbar.write(f'[{total_iter}] vae_recon_loss:{rec_loss.item()}')
            self.pbar.update(1)
        torch.save(self.model.module.state_dict(),
                   os.path.join(self.ckpt_dir, 'warmup.pt'))
        self.pbar.write("[Warmup Finished]")
        self.pbar.close()

    def rec_all(self, batch_size=1, same_depth=False):
        dataloader = DataLoader(self.dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=0,
                                pin_memory=True,
                                drop_last=False)
        labels = []
        self.model.eval()
        for i, dp in tqdm(enumerate(dataloader)):
            x, l, d = dp
            x = x.float().to(self.device)
            labels = labels + l
            if same_depth:
                d = d.unsqueeze(1).float().to(self.device).log()
                # same_depth = (same_depth - self.dataset.d_mean) / self.dataset.d_std
                d = (torch.ones_like(d) * same_depth).log()
            else:
                d = d.unsqueeze(1).float().to(self.device).log()
            with torch.no_grad():
                _, _, _, rec = self.model.forward(x, d)
                # rec = torch.sigmoid(rec).cpu()
                rec = rec.cpu()
                if i == 0:
                    out = rec
                else:
                    out = torch.cat((out, rec))
        return out, labels

    def rec_batch(self, batch_size=1, same_depth=False):
        dataloader = DataLoader(self.dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=0,
                                pin_memory=True,
                                drop_last=False)
        labels = []
        self.model.eval()
        for i, dp in tqdm(enumerate(dataloader)):
            x, l, d = dp
            x = x.float().to(self.device)
            labels = labels + l
            if same_depth:
                d = d.unsqueeze(1).float().to(self.device).log()
                # same_depth = (same_depth - self.dataset.d_mean) / self.dataset.d_std
                d = (torch.ones_like(d) * same_depth).log()

            else:
                d = d.unsqueeze(1).float().to(self.device).log()
            b = torch.zeros_like(d).float().to(self.device)
            with torch.no_grad():
                _, _, _, rec = self.model.forward(x, d, b)
                # rec = torch.sigmoid(rec).cpu()
                rec = rec.cpu()
                if i == 0:
                    out = rec
                else:
                    out = torch.cat((out, rec))
        return out, labels

    def inv_train(self):
        if not os.path.exists(self.ckpt_dir):
            print(f'Making dir {self.ckpt_dir}')
            os.makedirs(self.ckpt_dir)
        self.model.train()
        kl_list, rec_list = [], []
        print('Inv Training started')
        self.pbar = tqdm(total=self.max_epoch - self.start_epoch)
        total_iter = (self.start_epoch -
                      1) * self.dataset.__len__() // self.batch_size + 1
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.max_epoch):
            epoch_kl, epoch_rec = [], []
            kl_w = np.min([
                2 * (total_iter - (total_iter // self.cycle) * self.cycle) /
                self.cycle, 1
            ])
            for x1, s1, l1 in self.dataloader:
                x1 = x1.float().to(self.device)
                l1 = l1.log()
                l1 = (l1 - self.dataset.d_mean) / self.dataset.d_std
                l1 = l1.unsqueeze(1).float().to(self.device)
                if self.de_batch:
                    s1 = s1.unsqueeze(1).float().to(self.device)
                    z_mean, z_log_var, _, rec = self.model(x1, l1, b=s1)
                else:
                    z_mean, z_log_var, _, rec = self.model(x1, l1)
                mean = torch.zeros_like(z_mean)
                var = torch.ones_like(z_log_var)
                kld_z = kl(Normal(z_mean,
                                  torch.exp(z_log_var).sqrt()),
                           Normal(mean, var)).sum()
                pos_weight = torch.Tensor([self.pos_w]).to(self.device)
                bce = F.binary_cross_entropy_with_logits(rec,
                                                         x1,
                                                         pos_weight=pos_weight,
                                                         reduction='sum')
                rec_loss = bce
                m_kld = apprx_kl(
                    z_mean,
                    torch.exp(z_log_var).sqrt()).mean() - 0.5 * self.z_dim
                loss = kld_z * kl_w + (
                    1 + self.LAMBDA) * rec_loss + m_kld * kl_w * self.LAMBDA
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                epoch_kl.append(kld_z.item())
                epoch_rec.append(bce.item())
                total_iter += 1
            kl_list.append(np.mean(epoch_kl))
            rec_list.append(np.mean(epoch_rec))
            self.pbar.update(1)
            self.le_scdlr.step()
            # self.pbar.write(f'[{epoch}], iter {total_iter}')
            if epoch % self.out_every == 0:
                if epoch > self.start_save:
                    torch.save(self.model.module.state_dict(),
                               os.path.join(self.ckpt_dir, f'{epoch}.pt'))
                logdata = {
                    'iter': list(range(self.start_epoch, epoch + 1)),
                    'kl': kl_list,
                    'bce': rec_list
                }
                df = pd.DataFrame(logdata)
                df.to_csv(os.path.join(self.ckpt_dir, 'inv' + self.log),
                          index=False)
        self.pbar.write("[Inv Training Finished]")
        self.pbar.close()

    def encode_adv(self, batch_size=1000):
        dataloader = DataLoader(self.dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=0,
                                pin_memory=True,
                                drop_last=False)
        labels = []
        latent = torch.zeros(self.dataset.__len__(), self.z_dim)
        depth = torch.zeros(self.dataset.__len__())
        self.model.eval()
        for i, dp in tqdm(enumerate(dataloader)):
            x, l, d = dp
            x = x.float().to(self.device)
            if self.de_batch:
                labels = labels + list(l)
            else:
                labels = labels + l
            depth[i * batch_size:(i + 1) * batch_size] = d
            d = d.log()
            d = (d - self.dataset.d_mean) / self.dataset.d_std
            d = d.unsqueeze(1).float().to(self.device)
            with torch.no_grad():
                z_mean, _ = self.model.forward(x, d, no_rec=True)
                # z_mean, _, _, _ = self.model(x)
                latent[i * batch_size:(i + 1) * batch_size] = z_mean.cpu()
        return latent, labels, depth
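
The trainer above resumes an exponential schedule by constructing LambdaLR normally and then moving last_epoch to start_epoch - 1. A minimal sketch of that resume pattern follows; restoring the scheduler's state_dict from a checkpoint is the more robust alternative.

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.995 ** epoch)

start_epoch = 50                        # e.g. read from a checkpoint
scheduler.last_epoch = start_epoch - 1  # next step() evaluates the lambda at start_epoch

for epoch in range(start_epoch, start_epoch + 5):
    # ... one training epoch ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr())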
Example #11
0
def run_COPT(game, num_iter=5000, lr=0.5, seed=1234, biased=False,
             shuffling=False, lr_schedule=None,
             hamiltonian_coeff=10, **kwargs):
    config = Config(dict(mode="consensus_opt", num_iter=num_iter, lr=lr,
                         seed=seed, hamiltonian_coeff=hamiltonian_coeff,
                         shuffling=shuffling))
    torch.manual_seed(seed)
    game.reset()
    sgd = optim.SGD(game.parameters(), lr=lr)
    if lr_schedule is not None:
        lr_schedule = SchedulerLR(lr_schedule)
        scheduler = LambdaLR(sgd, lr_schedule)
    else:
        scheduler = LambdaLR(sgd, lambda k: 1.)
    logger = defaultdict(list)

    if kwargs["output"] is not None:
        path = os.path.join(kwargs["output"], config.name, str(seed))
        config["path"] = path
        if not os.path.exists(path):
            os.makedirs(os.path.join(path, "results"))

        config["name"] = config.name

        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump(config, f, default=lambda x: "non-serializable")

    if shuffling:
        game.shuffle()

    n_samples = 0
    start_time = time.time()
    for i in tqdm(range(num_iter)):
        index1 = game.sample()
        index2 = game.sample()
        if biased is True:
            grad1 = game.compute_grad(index1)
            grad2 = grad1
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 1

        elif biased == "copt":
            grad1 = game.compute_grad(torch.cat([index1, index2]))
            grad2 = grad1
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 2

        elif biased is False:
            grad1 = game.compute_grad(index1)
            grad2 = game.compute_grad(index2)
            hamiltonian = compute_hamiltonian(grad1)
            n_samples += 2

        else:
            raise ValueError(f'Unexpected value for biased: {biased!r}')

        grad_H = autograd.grad(hamiltonian, game.parameters())
        for p, g1, g2, gH in zip(game.parameters(), grad1, grad2, grad_H):
            p.grad = 0.5*(g1+g2) + hamiltonian_coeff*gH
        sgd.step()
        scheduler.step()

        metrics = game.compute_metrics()
        for key, value in metrics.items():
            logger[key].append(value)
        logger["lr"].append(scheduler.get_last_lr())
        logger["num_samples"].append(n_samples)
        logger["time"].append(time.time()-start_time)

        if kwargs["output"] is not None and i % 10000 == 0:
            with open(os.path.join(path, "results.json"), "w") as f:
                json.dump(logger, f)

    return logger, config
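
run_COPT wraps an optional lr_schedule in SchedulerLR and falls back to a constant LambdaLR otherwise; SchedulerLR itself is not shown. Here is a sketch of the same optional-schedule pattern, with a hypothetical step-decay callable standing in for it.

import torch
from torch.optim.lr_scheduler import LambdaLR

class DecayEvery:
    """Hypothetical stand-in for SchedulerLR: halve the lr every `period` steps."""
    def __init__(self, period):
        self.period = period
    def __call__(self, step):
        return 0.5 ** (step // self.period)

params = [torch.nn.Parameter(torch.zeros(3))]
sgd = torch.optim.SGD(params, lr=0.5)

lr_schedule = 1000  # or None to keep the learning rate constant
if lr_schedule is not None:
    scheduler = LambdaLR(sgd, DecayEvery(lr_schedule))
else:
    scheduler = LambdaLR(sgd, lambda k: 1.)  # multiplier of 1.0 at every step

for step in range(3000):
    sgd.step()
    scheduler.step()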
Example #12
0
def run_SHGD(game, num_iter=5000, lr=None, seed=1234, save_params=False,
             biased=False, shuffling=False, lr_schedule=None, **kwargs):
    if lr is None:
        lr = float(1/(2*game.L))
    if lr_schedule == "optimal":
        lr_schedule = int(4*(game.L/game.mu))

    config = Config(dict(mode="shgd", num_iter=num_iter, lr=lr, seed=seed,
                         biased=biased, lr_schedule=lr_schedule,
                         shuffling=shuffling))
    torch.manual_seed(seed)
    game.reset()

    sgd = optim.SGD(game.parameters(), lr=lr)
    if lr_schedule is not None:
        lr_schedule = SchedulerLR(lr_schedule)
        scheduler = LambdaLR(sgd, lr_schedule)
    else:
        scheduler = LambdaLR(sgd, lambda k: 1.)
    logger = defaultdict(list)

    if kwargs["output"] is not None:
        path = os.path.join(kwargs["output"], config.name, str(seed))
        config["path"] = path
        if not os.path.exists(path):
            os.makedirs(os.path.join(path, "results"))

        config["name"] = config.name

        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump(config, f, default=lambda x: "non-serializable")

    if shuffling:
        game.shuffle()

    n_samples = 0
    params_history = []
    start_time = time.time()
    for i in tqdm(range(num_iter)):
        index1 = game.sample()
        index2 = game.sample()
        if biased is True:
            hamiltonian = game.compute_hamiltonian(index1)
            n_samples += 1
        elif biased == "copt":
            hamiltonian = game.compute_hamiltonian(torch.cat([index1, index2]))
            n_samples += 2
        elif biased is False:
            hamiltonian = game.compute_hamiltonian(index1, index2)
            n_samples += 2
        else:
            raise ValueError(f'Unexpected value for biased: {biased!r}')
        grad = autograd.grad(hamiltonian, game.parameters())
        for p, g in zip(game.parameters(), grad):
            p.grad = g
        sgd.step()
        scheduler.step()

        metrics = game.compute_metrics()
        for key, value in metrics.items():
            logger[key].append(value)
        #logger["lr"].append(scheduler.get_last_lr())
        logger["num_samples"].append(n_samples)
        logger["time"].append(time.time()-start_time)

        if save_params:
            params_history.append(copy.deepcopy(game.state_dict()))

        if kwargs["output"] is not None and i % 10000 == 0:
            with open(os.path.join(path, "results.json"), "w") as f:
                json.dump(logger, f)

    logger["params"] = params_history
    return logger, config
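
The commented-out logging line above would record the current learning rate; get_last_lr() returns one value per parameter group. A small sketch of that logging pattern:

import torch
from collections import defaultdict
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(2))]
sgd = torch.optim.SGD(params, lr=0.1)
scheduler = LambdaLR(sgd, lambda k: 1.0 / (1.0 + k))

logger = defaultdict(list)
for i in range(5):
    sgd.step()
    scheduler.step()
    # get_last_lr() returns a list (one lr per param group); log the first group.
    logger["lr"].append(scheduler.get_last_lr()[0])
print(logger["lr"])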
Example #13
0
    x_train, x_test, y_train, y_test = get_data()
    train_dataloader = inf_data_gen(x_train, y_train, cfg.TRAIN.BATCH_SIZE)
    X_test = torch.Tensor(x_test).to(cfg.SYSTEM.DEVICE)
    Y_test = torch.Tensor(y_test).to(cfg.SYSTEM.DEVICE)

    net = Net(D=cfg.MODEL.D, W=cfg.MODEL.W)
    net.to(cfg.SYSTEM.DEVICE)

    optimizer = torch.optim.SGD(net.parameters(), lr=cfg.TRAIN.LEARNING_RATE)
    scheduler = LambdaLR(optimizer, lr_lambda=inv_root_lr)
    pbar = tqdm(train_dataloader, total=cfg.TRAIN.STEPS)
    for n_iter, (X, T) in enumerate(pbar, start=1):
        X, T = X.to(cfg.SYSTEM.DEVICE), T.to(cfg.SYSTEM.DEVICE)
        optimizer.zero_grad()
        net.train()
        train(net, X, T, optimizer, n_iter)

        if n_iter % 5000 == 0:
            net.eval()
            test(net, X_test, Y_test, n_iter)
            _WRITER.add_scalar('LR', get_lr(optimizer), n_iter)

        scheduler.step(n_iter)

        if n_iter % (cfg.TRAIN.STEPS / 10) == 0:
            analysis(net, x_train, y_train, n_iter)

        if n_iter > cfg.TRAIN.STEPS:
            break
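
The fragment above passes a named function, inv_root_lr, as lr_lambda and calls scheduler.step(n_iter) each iteration; the function's definition is not shown, and passing an epoch argument to step() is deprecated in recent PyTorch. Below is a sketch with an assumed inverse-square-root decay and a plain per-iteration step().

import torch
from torch.optim.lr_scheduler import LambdaLR

def inv_root_lr(step):
    # Assumed form: 1/sqrt(step) decay, with a guard so step 0 keeps the base lr.
    return 1.0 / max(step, 1) ** 0.5

net = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=inv_root_lr)

for n_iter in range(1, 101):
    optimizer.zero_grad()
    loss = net(torch.randn(4, 2)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per iteration; no explicit epoch argument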
Example #14
0
def main():
    global opt
    # train data loader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batchSize,
                                               shuffle=True,
                                               num_workers=int(opt.workers))

    # create model
    model = models.VAMetric_conv()

    if opt.init_model != '':
        print('loading pretrained model from {0}'.format(opt.init_model))
        model.load_state_dict(torch.load(opt.init_model))

    # Contrastive Loss
    criterion = models.conv_loss_dqy()

    if opt.cuda:
        print('shift model and criterion to GPU .. ')
        model = model.cuda()
        criterion = criterion.cuda()

    # optimizer
    # optimizer = optim.SGD(model.parameters(), lr=opt.lr,
    #                      momentum=opt.momentum,
    #                      weight_decay=opt.weight_decay)

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    # optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, momentum=opt.momentum)
    # optimizer = optim.Adadelta(params=model.parameters(), lr=opt.lr)
    # adjust learning rate every lr_decay_epoch
    lambda_lr = lambda epoch: opt.lr_decay**(
        (epoch + 1) // opt.lr_decay_epoch)  # poly policy
    scheduler = LR_Policy(optimizer, lambda_lr)

    resume_epoch = 0

    global dis1_rec
    global dis2_rec
    global loss_rec

    loss_rec = []
    dis1_rec = []
    dis2_rec = []

    ######### to test each epoch
    parser = OptionParser()
    parser.add_option('--config',
                      type=str,
                      help="evaluation configuration",
                      default="./configs/test_config.yaml")

    (opts_test, args) = parser.parse_args()
    opts_test = Config(opts_test.config)
    test_video_dataset = VideoFeatDataset(opts_test.data_dir,
                                          opts_test.video_flist,
                                          which_feat='vfeat')
    test_audio_dataset = VideoFeatDataset(opts_test.data_dir,
                                          opts_test.audio_flist,
                                          which_feat='afeat')
    test_video_loader = torch.utils.data.DataLoader(
        test_video_dataset,
        batch_size=opts_test.batchSize,
        shuffle=False,
        num_workers=int(opts_test.workers))
    test_audio_loader = torch.utils.data.DataLoader(
        test_audio_dataset,
        batch_size=opts_test.batchSize,
        shuffle=False,
        num_workers=int(opts_test.workers))

    ########

    # another test for git
    for epoch in range(resume_epoch, opt.max_epochs):
        #################################
        # train for one epoch
        #################################
        train(train_loader, model, criterion, optimizer, epoch, opt,
              test_video_loader, test_audio_loader, opts_test)
        scheduler.step()

        ##################################
        # save checkpoints
        ##################################

        # save model every 10 epochs
        if ((epoch + 1) % opt.epoch_save) == 0:
            path_checkpoint = '{0}/{1}_state_epoch{2}.pth'.format(
                opt.checkpoint_folder, opt.prefix, epoch + 1)
            utils.save_checkpoint(model.state_dict(), path_checkpoint)

    plt.figure(1)
    plt.subplot(1, 2, 1)
    plt.plot(loss_rec)
    plt.legend('loss')
    plt.subplot(1, 2, 2)
    plt.plot(dis1_rec)
    plt.plot(dis2_rec)
    plt.legend(('distance between positives', 'distance between negatives'))
    plt.savefig("./figures/conv.jpg")  # save before show(), which clears the figure
    plt.show()
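
LR_Policy above is used like LambdaLR with a step-decay lambda, opt.lr_decay ** ((epoch + 1) // opt.lr_decay_epoch). A sketch of that schedule with concrete values substituted for the opt fields:

import torch
from torch.optim.lr_scheduler import LambdaLR

lr_decay, lr_decay_epoch = 0.5, 10     # assumed values for opt.lr_decay / opt.lr_decay_epoch
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Multiply the base lr by 0.5 after every 10 completed epochs.
scheduler = LambdaLR(optimizer,
                     lr_lambda=lambda epoch: lr_decay ** ((epoch + 1) // lr_decay_epoch))

for epoch in range(30):
    # ... train(...) for one epoch ...
    scheduler.step()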
Example #15
0
def run_training(
        source,
        target,
        dataset_root,
        net_name,
        da_method,
        max_iter,
        stop_iter,
        test_iter,
        logdir,
        run_name,
        gpu_id,
        load_workers,
        config,
        test_src: bool = False,
        use_tqdm: bool = True,
        kill_diverging: bool = False):
    dev = torch.device(f'cuda:{gpu_id}')

    if kill_diverging:
        assert test_src

    # Get config
    # Config arrives here (from BOHB or direct cli invocation) as a dictionary like
    # {'disc.dropout': 0.5, 'net.bottleneck_size_log': 9}
    # We separate it in something like
    # {'disc': {'dropout': 0.5'}, 'net': {'bottleneck_size_log': 9}}
    config = split_dict(config)
    # Disc args are not meaningful without DA
    if da_method != 'so':
        # Default disc args
        disc_args = {
            'dropout': 0.5,
            'num_fc_layers': 3,
            'hidden_size_log': 10
        }
        # Update with the ones coming from config (if any)
        disc_args.update(config.get('disc', {}))
        # Some args might be defined as log2. Replace them (bottleneck_size_log -> bottleneck_size)
        remove_log_hps(disc_args)
        # Print disc args
        print(f"Discriminator config: {disc_args}")
    # Very similar, but for the backbone
    net_args = {
        'use_bottleneck': da_method != 'so',
        'bottleneck_size_log': 9
    }
    net_args.update(config.get('net', {}))
    remove_log_hps(net_args)
    print(f"Backbone config: {net_args}")
    # Now net_args and disc_args are ready to be passed to the network constructors as **kwargs :)
    bs, lr, wd = config['base']['bs'], config['base']['lr'], config['base']['wd']

    # Load datasets and their number o classes
    dset_src_train, dset_src_test, dset_trg_train, dset_trg_test, num_classes = \
        prepare_datasets(source, target, dataset_root)

    dload_src_train = DataLoader(dset_src_train, batch_size=bs, shuffle=True, num_workers=load_workers, drop_last=True)
    dload_src_test = DataLoader(dset_src_test, batch_size=bs, shuffle=False, num_workers=load_workers)
    dload_trg_train = DataLoader(dset_trg_train, batch_size=bs, shuffle=True, num_workers=load_workers, drop_last=True)
    dload_trg_test = DataLoader(dset_trg_test, batch_size=bs, shuffle=False, num_workers=load_workers)

    print(f"Source samples: {len(dset_src_train)}")
    print(f"Target samples: {len(dset_trg_train)}")
    print(f"Num classes: {num_classes}")

    # Build network
    base_network = resnet.ResNetFc(
        resnet_name=net_name,
        num_classes=num_classes,
        plug_position=7,
        **net_args
    ).to(dev)
    params = base_network.get_parameters(lr, wd)
    # Source only has no secondary branches
    if da_method != 'so':
        disc_classes = {
            # ( -> confusion matrix)
            'alda': num_classes,
            # ( -> binary domain classifier)
            'dann': 2
        }[da_method]
        discriminator = resnet.Discriminator(in_feature=base_network.output_size(), num_classes=disc_classes,
                                             **disc_args).to(dev)
        params += discriminator.get_parameters(lr, wd)

    # Define optimizer
    optimizer = opt.SGD(
        params=params,
        lr=lr,
        momentum=0.9,
        weight_decay=wd,
        nesterov=True
    )

    # Lr policy
    lr_schedule = LambdaLR(optimizer, lr_lambda=lambda it: (1 + 0.001 * it) ** (-0.75))

    # Logger
    writer = Logger(logdir=logdir, run_name=run_name, use_tb=True, use_tqdm=use_tqdm)

    # Classification loss
    ce_loss = nn.CrossEntropyLoss()

    # Train loop
    len_train_source = len(dload_src_train)
    len_train_target = len(dload_trg_train)
    lambda_val = 0.

    # We store all the metrics here
    metrics = []

    all_pseudolabels = []

    with writer.progress(total=stop_iter, desc="Training") as pb:
        for i in range(stop_iter):
            if (i + 1) % test_iter == 0:
                print(f"Iteration: {i + 1} / {stop_iter} (max: {max_iter})")
                print("Testing...")
                base_network.train(False)
                # This dict contains metric-name -> value pairs for the current epoch
                new_metrics = {}
                src_test_feats = None  # only populated when test_src is True
                if test_src:
                    test_result, _, src_test_feats = test(dload_src_test, base_network, device=dev)
                    # Print accuracy
                    print("Source accuracy: {:.3f} %".format(test_result['accuracy'] * 100))
                    # Add the source metrics to the dict (with the source_ prefix)
                    new_metrics.update({f'source_{k}': v for k, v in test_result.items()})

                test_result, epoch_pseudolabels, _ = test(dload_trg_test, base_network, device=dev,
                                                          source_feats=src_test_feats)
                all_pseudolabels.append(epoch_pseudolabels)
                print(f"Target accuracy: {test_result['accuracy'] * 100:.3f} %")

                writer.add_scalar('train/base_lr', lr_schedule.get_last_lr()[0], i)
                writer.add_scalar('train/lambda', lambda_val, i)

                new_metrics.update({f'target_{k}': v for k, v in test_result.items()})

                # Add all the new metrics to tensorboard logs
                add_scalars(writer, new_metrics, global_step=i, prefix='test/')
                # Add a column with iteration number
                new_metrics.update({'iter': i})
                # Concatenate to older epoch metrics
                metrics.append(new_metrics)

                # Kill this training if source loss goes too high
                if kill_diverging and new_metrics['source_class_loss'] > SOURCE_LOSS_THRESHOLD:
                    # new_metrics was just appended, so compare against the previous epoch's entry
                    if len(metrics) > 1 and new_metrics['source_class_loss'] > metrics[-2]['source_class_loss']:
                        print(f"Increasing source_class_loss exceeds maximum allowed source loss ({new_metrics['source_class_loss']} > {SOURCE_LOSS_THRESHOLD})")
                        break

            # Train one iteration
            base_network.train(True)
            if da_method != 'so':
                discriminator.train(True)

            optimizer.zero_grad()

            # Reset data loops if required
            if i % len_train_source == 0:
                iter_source = iter(dload_src_train)
            if i % len_train_target == 0:
                iter_target = iter(dload_trg_train)

            # Load source
            inputs_source, labels_source = next(iter_source)
            inputs_source, labels_source = map_to_device(dev, (inputs_source, labels_source))

            # Compute source features and classification output
            outputs_source, features_source = base_network(inputs_source)

            # Classification loss
            classifier_loss = ce_loss(outputs_source, labels_source)

            # Actual DA part
            if da_method != 'so':
                # Load target samples without target labels
                inputs_target, _ = next(iter_target)
                inputs_target = inputs_target.to(dev)

                # Compute target features and classification output
                outputs_target, features_target = base_network(inputs_target)

                # Source and target features
                features = torch.cat((features_source, features_target), dim=0)
                # Source and target classification outputs (-> softmax)
                outputs = torch.cat((outputs_source, outputs_target), dim=0)
                softmax_out = nn.Softmax(dim=1)(outputs)

                # CORE
                if da_method == 'dann':
                    p = float(i / max_iter)
                    lambda_val = 2. / (1 + np.exp(-10 * p)) - 1
                    ad_out = discriminator(features, lambda_val)
                    adv_loss = loss.DANN_loss(ad_out)
                    transfer_loss = adv_loss
                    if (i + 1) % test_iter == 0:
                        print("Transfer loss: {:.3f}".format(transfer_loss.item()))
                elif da_method == 'alda':
                    p = float(i / max_iter)
                    lambda_val = 2. / (1 + np.exp(-10 * p)) - 1
                    ad_out = discriminator(features, lambda_val)
                    adv_loss, reg_loss, correct_loss = loss.ALDA_loss(ad_out, labels_source, softmax_out, threshold=0.9)

                    transfer_loss = adv_loss + lambda_val * correct_loss
                    if (i + 1) % test_iter == 0:
                        print("Transfer loss: {:.3f}, reg loss  {:.3f}%".format(transfer_loss.item(),
                                                                                reg_loss.item()))
                    # Backpropagate reg_loss only through the discriminator
                    with base_network.freeze():
                        reg_loss.backward(retain_graph=True)
                # END CORE
            else:
                transfer_loss = 0

            total_loss = classifier_loss + config['base']['weight_da'] * transfer_loss
            total_loss.backward()

            optimizer.step()
            lr_schedule.step()

            if (i + 1) % test_iter == 0 and da_method != 'so':
                writer.add_scalar('train/transfer_loss', transfer_loss.item(), i)
            pb.update(1)

    # Convert list of dicts to dataframe containing metrics
    metrics = pd.DataFrame(metrics)

    # Compute global-pseudolabel accuracy
    all_pseudolabels = np.array(all_pseudolabels)
    global_pseudolabels = compute_time_consistent_pseudolabels(all_pseudolabels, num_classes)
    pseudolabel_acc = np.equal(all_pseudolabels, global_pseudolabels).sum(axis=1) / global_pseudolabels.shape[0]
    # Add it to the metrics dataframe
    metrics['target_pseudolabels'] = pseudolabel_acc

    # Save the metrics
    with open(os.path.join(logdir, run_name, "metrics.pkl"), "wb") as fp:
        pickle.dump(metrics, fp)

    # Log global pseudolabel accuracy to tensorboard
    for i in range(len(all_pseudolabels)):
        writer.add_scalar('test/target_pseudolabels', float(pseudolabel_acc[i]), i * test_iter)

    return metrics
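
The training loop above steps LambdaLR once per iteration with lr_t = base_lr * (1 + 0.001 * t) ** -0.75 and reads the current value back through get_last_lr() for logging. A condensed sketch of just that schedule:

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)
lr_schedule = LambdaLR(optimizer, lr_lambda=lambda it: (1 + 0.001 * it) ** (-0.75))

for it in range(5000):
    # ... forward / backward would run here ...
    optimizer.step()
    lr_schedule.step()
    if (it + 1) % 1000 == 0:
        print(it + 1, lr_schedule.get_last_lr()[0])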
Example #16
0
    def train(self) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """

        profiling_wrapper.configure(
            capture_start_step=self.config.PROFILING.CAPTURE_START_STEP,
            num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE,
        )

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
        )
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)
        batch = apply_obs_transforms_batch(batch, self.obs_transforms)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats: DefaultDict[str, deque] = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES
                                             ),  # type: ignore
        )

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(self.config.NUM_UPDATES):
                profiling_wrapper.on_start_step()
                profiling_wrapper.range_push("train update")

                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()  # type: ignore

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                profiling_wrapper.range_push("rollouts loop")
                for _step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps
                profiling_wrapper.range_pop()  # rollouts loop

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                for k, v in running_episode_stats.items():
                    window_episode_stats[k].append(v.clone())

                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items() if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                losses = [value_loss, action_loss]
                writer.add_scalars(
                    "losses",
                    {k: l
                     for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    logger.info("Average window size: {}  {}".format(
                        len(window_episode_stats["count"]),
                        "  ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                  for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth",
                                         dict(step=count_steps))
                    count_checkpoints += 1

                profiling_wrapper.range_pop()  # train update

            self.envs.close()
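
The PPO trainer above decays the learning rate with linear_decay(x, NUM_UPDATES); the helper's definition is not shown here, so the one below is an assumed linear ramp from 1.0 to 0.0. A minimal sketch of that update-indexed schedule:

import torch
from torch.optim.lr_scheduler import LambdaLR

NUM_UPDATES = 10000

def linear_decay(update, total_updates):
    # Assumed form: 1.0 at the first update, approaching 0.0 at the last one.
    return 1.0 - update / float(total_updates)

agent = torch.nn.Linear(32, 4)
optimizer = torch.optim.Adam(agent.parameters(), lr=2.5e-4)
lr_scheduler = LambdaLR(optimizer,
                        lr_lambda=lambda x: linear_decay(x, NUM_UPDATES))

for update in range(NUM_UPDATES):
    # ... collect rollouts and run the PPO update here ...
    lr_scheduler.step()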
Example #17
0
def train(
    run_name: str,
    # Data
    train_filepath: str = CSNJS_TRAIN_FILEPATH,
    eval_filepath: str = CSNJS_VALID_FILEPATH,
    spm_filepath: str = SPM_UNIGRAM_FILEPATH,
    program_mode="identity",
    eval_program_mode="identity",
    label_mode="identifier",
    num_workers=1,
    limit_dataset_size=-1,
    # Model
    model_type="transformer",
    n_decoder_layers=4,
    d_model: int = 512,
    resume_path: str = "",
    resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    resume_project: bool = False,
    # Optimization
    train_decoder_only: bool = False,
    num_epochs: int = 50,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    use_lr_warmup: bool = True,
    loss_type="nll_token",  # nll_token or nll_sequence
    # Loss
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    auto_test: bool = True,
    seed: int = 0,
):
    """Train model"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name,
               config=config,
               job_type="training",
               project="identifier-prediction",
               entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(
        ), "CUDA not available. Check env configuration, or pass --use_cuda False"

    train_augmentations = [
        {
            "fn": "sample_lines",
            "line_length_pct": 0.5
        },
        {
            "fn": "insert_var_declaration",
            "prob": 0.5
        },
        {
            "fn": "rename_variable",
            "prob": 0.5
        },
    ]
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    train_dataset = get_csnjs_dataset(train_filepath,
                                      label_mode=label_mode,
                                      limit_size=limit_dataset_size)
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = javascript_dataloader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        augmentations=train_augmentations,
        sp=sp,
        program_mode=program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = get_csnjs_dataset(eval_filepath,
                                     label_mode=label_mode,
                                     limit_size=limit_dataset_size)
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = javascript_dataloader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        augmentations=[],
        sp=sp,
        program_mode=eval_program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create model
    pad_id = sp.PieceToId("[PAD]")
    if model_type == "transformer":
        model = TransformerModel(n_tokens=sp.GetPieceSize(),
                                 pad_id=pad_id,
                                 n_decoder_layers=n_decoder_layers,
                                 d_model=d_model)
        logger.info(
            f"Created TransformerModel with {count_parameters(model)} params")
    elif model_type == "lstm":
        model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(),
                            pad_id=pad_id,
                            d_model=d_model)
        logger.info(
            f"Created Seq2SeqLSTM with {count_parameters(model)} params")

    # Load checkpoint
    if resume_path:
        logger.info(
            f"Resuming training from checkpoint {resume_path}, resume_encoder_name={resume_encoder_name}"
        )
        checkpoint = torch.load(resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        assert resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]

        for key, value in pretrained_state_dict.items():
            if key.startswith(resume_encoder_name +
                              ".") and "project_layer" not in key:
                remapped_key = key[len(resume_encoder_name + "."):]
                logger.debug(
                    f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}"
                )
                encoder_state_dict[remapped_key] = value
            if key.startswith(
                    resume_encoder_name +
                    ".") and "project_layer.0." in key and resume_project:
                remapped_key = key[len(resume_encoder_name + "."):]
                logger.debug(
                    f"Remapping checkpoint project key {key} to {remapped_key}. Value mean: {value.mean().item()}"
                )
                encoder_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict, strict=False)
        logger.info(f"Loaded state dict from {resume_path}")
        logger.info(f"Loaded keys: {encoder_state_dict.keys()}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    params = model.module.decoder.parameters(
    ) if train_decoder_only else model.parameters()
    optimizer = torch.optim.Adam(params,
                                 lr=lr,
                                 betas=(adam_beta1, adam_beta2),
                                 eps=1e-9)
    if use_lr_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer, 5000,
            len(train_loader) * num_epochs)
    else:
        scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0)

    global_step = 0
    min_eval_loss = float("inf")
    for epoch in tqdm.trange(1,
                             num_epochs + 1,
                             desc="training",
                             unit="epoch",
                             leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        if train_decoder_only:
            model.module.encoder.eval()
            model.module.decoder.train()
        else:
            model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, Y, X_lengths, Y_lengths in pbar:
            if use_cuda:
                X = X.cuda()
                Y = Y.cuda()
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            optimizer.zero_grad()
            # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
            logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
            if loss_type == "nll_sequence":
                loss = F.cross_entropy(logits.transpose(1, 2),
                                       Y[:, 1:],
                                       ignore_index=pad_id,
                                       reduction='sum')
                loss = loss / X.size(
                    0
                )  # Average over num sequences, not target sequence lengths
                # Thus, minimize bits per sequence.
            elif loss_type == "nll_token":
                loss = F.cross_entropy(
                    logits.transpose(1, 2),
                    Y[:, 1:],
                    ignore_index=pad_id,
                )
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Log loss
            global_step += 1
            wandb.log(
                {
                    "epoch": epoch,
                    f"label-{label_mode}/train_loss": loss.item(),
                    "lr": scheduler.get_last_lr()[0]
                },
                step=global_step)
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(
            f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        max_decode_len = 20 if label_mode == "identifier" else 200
        eval_loss = _evaluate(model,
                              eval_loader,
                              sp,
                              use_cuda=use_cuda,
                              max_decode_len=max_decode_len,
                              loss_type=loss_type)
        logger.info(
            f"Evaluation loss after epoch {epoch} ({global_step} steps): {eval_loss:.4f}"
        )
        wandb.log({
            "epoch": epoch,
            f"label-{label_mode}/eval_loss": eval_loss
        },
                  step=global_step)

        # Save checkpoint
        if save_every and epoch % save_every == 0 or eval_loss < min_eval_loss:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_loss": eval_loss,
            }
            if eval_loss < min_eval_loss:
                logger.info(
                    f"New best evaluation loss: prev {min_eval_loss:.4f} > new {eval_loss:.4f}"
                )
                min_eval_loss = eval_loss
                model_file = run_dir / "ckpt_best.pth"
            else:
                model_file = run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            wandb.save(str(model_file.resolve()))
            logger.info("Done.")

    if auto_test:
        best_ckpt = run_dir / "ckpt_best.pth"
        test(
            str(best_ckpt.resolve()),
            CSNJS_TEST_FILEPATH,
            spm_filepath,
            program_mode,
            label_mode,
            num_workers,
            -1,
            n_decoder_layers=n_decoder_layers,
        )
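
When use_lr_warmup is off, the code above falls back to a constant LambdaLR; when it is on, it uses get_linear_schedule_with_warmup. Below is a rough sketch of expressing linear warmup followed by linear decay directly as a LambdaLR lambda (an approximation, not the library helper itself).

import torch
from torch.optim.lr_scheduler import LambdaLR

warmup_steps, total_steps = 5000, 100000
model = torch.nn.Linear(512, 512)
optimizer = torch.optim.Adam(model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)

def warmup_then_linear_decay(step):
    # Ramp the multiplier from 0 to 1 over warmup_steps, then decay linearly to 0.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

scheduler = LambdaLR(optimizer, lr_lambda=warmup_then_linear_decay)

for step in range(total_steps):
    # ... forward, backward, optimizer.step() ...
    scheduler.step()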
Example #18
0
    def training(  # noqa: C901
        self,
        tproblem,
        hyperparams,
        num_epochs,
        print_train_iter,
        train_log_interval,
        tb_log,
        tb_log_dir,
        **training_params,
    ):
        """Training loop for this runner.

        Args:
            tproblem (deepobs.pytorch.testproblems.testproblem): The testproblem
                instance to train on.
            hyperparams (dict): The optimizer hyperparameters to use for the training.
            num_epochs (int): The number of training epochs.
            print_train_iter (bool): Whether to print the training progress at
                every train_log_interval
            train_log_interval (int): Mini-batch interval for logging.
            tb_log (bool): Whether to use tensorboard logging or not
            tb_log_dir (str): The path where to save tensorboard events.
            **training_params (dict): Kwargs for additional training parameters
                that will be used by the cockpit.

        Returns:
            dict: Output of the training loop
        """
        opt = self._optimizer_class(tproblem.net.parameters(), **hyperparams)

        # Using a LR Scheduler
        lr_sched = training_params["lr_schedule"](num_epochs)
        scheduler = LambdaLR(opt, lr_lambda=lr_sched)

        # COCKPIT: Initialize it #
        logpath = self._get_cockpit_logpath()

        # Integrate BackPACK
        extend_with_access_unreduced_loss(tproblem)

        trainable_params = [
            p for p in tproblem.net.parameters() if p.requires_grad
        ]
        cockpit = Cockpit(trainable_params, quantities=self._quantities)

        plotter = CockpitPlotter(secondary_screen=self._secondary_screen)
        if self._plot_schedule is not None:
            plot_schedule = self._plot_schedule
        else:
            warnings.warn(
                "You are using plot_interval, which will be deprecated. " +
                "Use plot_schedule instead")
            plot_schedule = schedules.linear(training_params["plot_interval"])

        # Lists to log train/test loss and accuracy.
        train_losses = []
        valid_losses = []
        test_losses = []
        train_accuracies = []
        valid_accuracies = []
        test_accuracies = []
        minibatch_train_losses = []

        if tb_log:
            try:
                from torch.utils.tensorboard import SummaryWriter

                summary_writer = SummaryWriter(log_dir=tb_log_dir)
            except ImportError as e:
                warnings.warn(
                    "Not possible to use tensorboard for pytorch. Reason: " +
                    e.msg,
                    RuntimeWarning,
                )
                tb_log = False
        global_step = 0

        for epoch_count in range(num_epochs + 1):
            # Evaluate at beginning of epoch.
            if self._should_eval():
                self.evaluate_all(
                    epoch_count,
                    num_epochs,
                    tproblem,
                    train_losses,
                    valid_losses,
                    test_losses,
                    train_accuracies,
                    valid_accuracies,
                    test_accuracies,
                )

                # COCKPIT: Log already computed quantities #
                cockpit.log(
                    global_step,
                    epoch_count,
                    train_losses[-1],
                    valid_losses[-1],
                    test_losses[-1],
                    train_accuracies[-1],
                    valid_accuracies[-1],
                    test_accuracies[-1],
                    opt.param_groups[0]["lr"],
                )

            # Break from train loop after the last round of evaluation
            if epoch_count == num_epochs:
                break

            # Training #

            # set to training mode
            tproblem.train_init_op()
            batch_count = 0
            while True:
                try:
                    opt.zero_grad()

                    batch_loss, _ = tproblem.get_batch_loss_and_accuracy(
                        reduction="mean")

                    info = {
                        "batch_size": self._extract_batch_size(batch_loss),
                        "individual_losses": self._extract_individual_losses(batch_loss),
                        "loss": batch_loss,
                        "optimizer": opt,
                    }

                    # COCKPIT: Use necessary BackPACK extensions and track #
                    with cockpit(global_step, info=info):
                        batch_loss.backward(
                            create_graph=cockpit.create_graph(global_step))

                    if plot_schedule(global_step):
                        plotter.plot(
                            cockpit,
                            savedir=logpath,
                            show_plot=training_params["show_plots"],
                            save_plot=training_params["save_plots"],
                            savename_append="__epoch__" +
                            str(epoch_count).zfill(len(str(num_epochs))) +
                            "__global_step__" + str(global_step).zfill(6),
                        )

                    opt.step()

                    if batch_count % train_log_interval == 0:
                        minibatch_train_losses.append(batch_loss.item())
                        if print_train_iter:
                            print("Epoch {0:d}, step {1:d}: loss {2:g}".format(
                                epoch_count, batch_count, batch_loss))
                        if tb_log:
                            summary_writer.add_scalar("loss",
                                                      batch_loss.item(),
                                                      global_step)

                    batch_count += 1
                    global_step += 1

                    self._maybe_stop_iteration(global_step, batch_count)

                except StopIteration:
                    break

            # Next step in LR Schedule
            scheduler.step()

        # COCKPIT: Write to file and optionally plot after last epoch #
        cockpit.write(logpath)

        if self._enable_plotting:
            plotter.plot(
                cockpit,
                savedir=logpath,
                show_plot=training_params["show_plots"],
                save_plot=training_params["save_final_plot"],
            )

            if training_params["save_animation"]:
                plotter.build_animation(logpath)

        if tb_log:
            summary_writer.close()

        # Put results into output dictionary.
        output = {
            "train_losses": train_losses,
            "valid_losses": valid_losses,
            "test_losses": test_losses,
            "minibatch_train_losses": minibatch_train_losses,
            "train_accuracies": train_accuracies,
            "valid_accuracies": valid_accuracies,
            "test_accuracies": test_accuracies,
        }

        return output
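The runner above expects training_params["lr_schedule"] to be a factory that maps num_epochs to the multiplier function handed to LambdaLR. A minimal sketch of such a factory (the exponential form is an assumption, not part of the runner):

import torch
from torch.optim.lr_scheduler import LambdaLR

def lr_schedule(num_epochs):
    # epoch -> multiplier; decays the base lr to ~1% of its value by the last epoch
    gamma = 0.01 ** (1.0 / max(1, num_epochs))
    return lambda epoch: gamma ** epoch

params = [torch.nn.Parameter(torch.zeros(3))]
opt = torch.optim.SGD(params, lr=0.1)
scheduler = LambdaLR(opt, lr_lambda=lr_schedule(50))  # stepped once per epoch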
Example #19
0
def main():
    global opt
    loss_rec = np.zeros((opt.folds, 100))
    acc_rec = np.zeros((opt.folds, 100))
    #loss_rec = np.load('acc_train.npy')
    #acc_rec = np.load('acc.npy')
    for iteration in range(opt.folds):
        train_dataset = mnist_Dataset(num_of_cross=iteration)

        print('number of train samples is: {0}'.format(len(train_dataset)))
        print('finished loading data')

        if opt.manualSeed is None:
            opt.manualSeed = random.randint(1, 10000)

        if torch.cuda.is_available() and not opt.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably run with \"cuda: True\""
            )
            torch.manual_seed(opt.manualSeed)
        else:
            if int(opt.ngpu) == 1:
                print('using 1 gpu for training')
                print('setting gpu on gpuid {0}'.format(opt.gpu_id))

                if opt.cuda:
                    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
                    torch.cuda.manual_seed(opt.manualSeed)
                    cudnn.benchmark = True
        print('Random Seed: {0}'.format(opt.manualSeed))
        # train data loader
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batchSize,
                                                   shuffle=True,
                                                   num_workers=int(
                                                       opt.workers))

        # create model
        model = mnist_model.cat_and_dog_resnet()

        if opt.init_model != '':
            print('loading pretrained model from {0}'.format(opt.init_model))
            model.load_state_dict(torch.load(opt.init_model))

        # Contrastive Loss
        #criterion = mnist_model.StableBCELoss()
        criterion = nn.CrossEntropyLoss()

        if opt.cuda:
            print('moving model and criterion to GPU ..')
            model = model.cuda()
            criterion = criterion.cuda()

        # optimizer
        # optimizer = optim.SGD(model.parameters(), lr=opt.lr,
        #                      momentum=opt.momentum,
        #                      weight_decay=opt.weight_decay)

        optimizer = optim.Adam(model.parameters(), lr=opt.lr)
        # optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, momentum=opt.momentum)
        # optimizer = optim.Adadelta(params=model.parameters(), lr=opt.lr)
        # step decay: multiply the learning rate by lr_decay every lr_decay_epoch epochs
        lambda_lr = lambda epoch: opt.lr_decay**(
            (epoch + 1) // opt.lr_decay_epoch)
        scheduler = LR_Policy(optimizer, lambda_lr)

        resume_epoch = 0
        acc = test(model, opt, iteration)
        acc_rec[iteration][0] = acc
        acc = test(model, opt, iteration, Training=True)
        loss_rec[iteration][0] = acc
        for epoch in range(resume_epoch, opt.max_epochs):
            #################################
            # train for one epoch
            #################################
            #accuracy = test(model, opt, epoch)
            train(train_loader, model, criterion, optimizer, iteration, opt,
                  epoch)
            scheduler.step()

            ##################################
            # save checkpoints
            ##################################

            # save model every 10 epochs
            accuracy = test(model, opt, iteration)
            acc_rec[iteration][epoch + 1] = accuracy
            np.save('acc.npy', acc_rec)
            accuracy = test(model, opt, iteration, Training=True)
            loss_rec[iteration][epoch + 1] = accuracy
            np.save('acc_train.npy', loss_rec)

            if ((epoch + 1) % opt.epoch_save) == 0:
                path_checkpoint = '{0}/{1}_{3}_epoch{2}.pth'.format(
                    opt.checkpoint_folder, opt.prefix, epoch + 1, iteration)
                utils.save_checkpoint(model.state_dict(), path_checkpoint)
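The lambda above is a step decay: the multiplier is lr_decay raised to the number of completed lr_decay_epoch intervals, so the learning rate drops by a factor of lr_decay every lr_decay_epoch epochs. A standalone sketch with plain LambdaLR (lr_decay=0.5 and lr_decay_epoch=10 are assumed values):

import torch
from torch.optim.lr_scheduler import LambdaLR

lr_decay, lr_decay_epoch = 0.5, 10
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = LambdaLR(optimizer,
                     lr_lambda=lambda epoch: lr_decay ** ((epoch + 1) // lr_decay_epoch))
for epoch in range(30):
    # ... train one epoch ...
    scheduler.step()  # lr: 1e-3 for the first ~9 epochs, then 5e-4, then 2.5e-4, ...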
Example #20
0
def learn(device, env, nenv, seed, number_timesteps, network, optimizer,
          save_path, save_interval, ob_scale, lr, gamma, grad_norm,
          timesteps_per_batch, ent_coef, vf_coef, **kwargs):
    """
    Paper:
    Mnih V, Badia A P, Mirza M, et al. Asynchronous methods for deep
    reinforcement learning[C]// International Conference on Machine Learning.
    2016: 1928-1937.

    Parameters:
    ----------
    grad_norm (float | None): max gradient norm for clipping (None disables clipping)
    timesteps_per_batch (int): number of steps per update
    ent_coef (float): policy entropy coefficient in the objective
    vf_coef (float): value function loss coefficient in the objective

    """
    name = '{}_{}'.format(os.path.split(__file__)[-1][:-3], seed)
    logger = get_logger(name)

    policy = build_policy(env, network, estimate_value=True).to(device)
    optimizer = get_optimizer(optimizer, policy.parameters(), lr)
    number_timesteps = number_timesteps // nenv
    generator = _generate(device, env, policy, ob_scale, number_timesteps,
                          gamma, timesteps_per_batch)
    max_iter = number_timesteps // timesteps_per_batch
    scheduler = LambdaLR(optimizer, lambda i_iter: 1 - i_iter / max_iter)

    total_timesteps = 0
    infos = {
        k: deque(maxlen=100)
        for k in ['eplenmean', 'eprewmean', 'pgloss', 'v', 'entropy']
    }
    start_ts = time.time()
    for n_iter in range(1, max_iter + 1):
        scheduler.step()

        batch = generator.__next__()
        b_o, b_a, b_r, b_v_old, info = batch
        for d in info:
            infos['eplenmean'].append(d['l'])
            infos['eprewmean'].append(d['r'])
        total_timesteps += b_o[0].size(0)

        # calculate advantage
        b_logits, b_v = policy(b_o)
        b_v = b_v[:, 0]
        dist = torch.distributions.Categorical(logits=b_logits)
        entropy = dist.entropy().mean()
        b_logp = dist.log_prob(b_a)
        adv = b_r - b_v_old

        # update policy
        vloss = (b_v - b_r).pow(2).mean()
        pgloss = -(adv * b_logp).mean()
        loss = pgloss + vf_coef * vloss - ent_coef * entropy
        optimizer.zero_grad()
        loss.backward()
        if grad_norm is not None:
            nn.utils.clip_grad_norm_(policy.parameters(), grad_norm)
        optimizer.step()

        # record logs
        infos['pgloss'].append(pgloss.item())
        infos['v'].append(vloss.item())
        infos['entropy'].append(entropy.item())
        logger.info('{} Iter {} {}'.format('=' * 10, n_iter, '=' * 10))
        fps = int(total_timesteps / (time.time() - start_ts))
        logger.info('Total timesteps {} FPS {}'.format(total_timesteps, fps))
        for k, v in infos.items():
            v = (sum(v) / len(v)) if v else float('nan')
            logger.info('{}: {:.6f}'.format(k, v))
        if save_interval and n_iter % save_interval == 0:
            torch.save([policy.state_dict(),
                        optimizer.state_dict()],
                       os.path.join(save_path, '{}.{}'.format(name, n_iter)))
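The docstring's coefficients enter the objective exactly as in the update block above: loss = pgloss + vf_coef * vloss - ent_coef * entropy. A minimal sketch of that objective with dummy tensors (all shapes and coefficient values are assumptions):

import torch

vf_coef, ent_coef = 0.5, 0.01
b_logits = torch.randn(32, 6)        # policy logits [batch, n_actions]
b_a = torch.randint(0, 6, (32,))     # sampled actions
b_r = torch.randn(32)                # returns
b_v = torch.randn(32)                # value predictions
b_v_old = b_v.detach()

dist = torch.distributions.Categorical(logits=b_logits)
adv = b_r - b_v_old                            # advantage estimate
pgloss = -(adv * dist.log_prob(b_a)).mean()    # policy gradient term
vloss = (b_v - b_r).pow(2).mean()              # value function term
loss = pgloss + vf_coef * vloss - ent_coef * dist.entropy().mean()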
Example #21
0
    def lr_range_test(self,
                      data_loader,
                      end_lr,
                      num_iter=100,
                      step_mode='exp',
                      alpha=0.05,
                      ax=None):
        # Since the test updates both model and optimizer we need to store
        # their initial states to restore them in the end
        previous_states = {
            'model': deepcopy(self.model.state_dict()),
            'optimizer': deepcopy(self.optimizer.state_dict())
        }
        # Retrieves the learning rate set in the optimizer
        start_lr = self.optimizer.state_dict()['param_groups'][0]['lr']

        # Builds a custom function and corresponding scheduler
        lr_fn = make_lr_fn(start_lr, end_lr, num_iter)
        scheduler = LambdaLR(self.optimizer, lr_lambda=lr_fn)

        # Variables for tracking results and iterations
        tracking = {'loss': [], 'lr': []}
        iteration = 0

        # If there are more iterations than mini-batches in the data loader,
        # it will have to loop over it more than once
        while (iteration < num_iter):
            # That's the typical mini-batch inner loop
            for x_batch, y_batch in data_loader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)
                # Step 1
                yhat = self.model(x_batch)
                # Step 2
                loss = self.loss_fn(yhat, y_batch)
                # Step 3
                loss.backward()

                # Here we keep track of the losses (smoothed)
                # and the learning rates
                tracking['lr'].append(scheduler.get_last_lr()[0])
                if iteration == 0:
                    tracking['loss'].append(loss.item())
                else:
                    prev_loss = tracking['loss'][-1]
                    smoothed_loss = alpha * loss.item() + (1 - alpha) * prev_loss
                    tracking['loss'].append(smoothed_loss)

                iteration += 1
                # Number of iterations reached
                if iteration == num_iter:
                    break

                # Step 4
                self.optimizer.step()
                scheduler.step()
                self.optimizer.zero_grad()

        # Restores the original states
        self.optimizer.load_state_dict(previous_states['optimizer'])
        self.model.load_state_dict(previous_states['model'])

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        else:
            fig = ax.get_figure()
        ax.plot(tracking['lr'], tracking['loss'])
        if step_mode == 'exp':
            ax.set_xscale('log')
        ax.set_xlabel('Learning Rate')
        ax.set_ylabel('Loss')
        fig.tight_layout()
        return tracking, fig
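make_lr_fn is defined outside this snippet; a plausible sketch of it (an assumption, not the author's exact code) returns a multiplier that sweeps the learning rate from start_lr to end_lr over num_iter steps, geometrically for the default 'exp' mode:

import numpy as np

def make_lr_fn(start_lr, end_lr, num_iter, step_mode='exp'):
    # Returns iteration -> multiplier applied to start_lr by LambdaLR.
    if step_mode == 'linear':
        factor = (end_lr / start_lr - 1) / num_iter
        return lambda iteration: 1 + iteration * factor
    # exponential sweep: multiplier reaches end_lr / start_lr at iteration == num_iter
    factor = (np.log(end_lr) - np.log(start_lr)) / num_iter
    return lambda iteration: float(np.exp(factor * iteration))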
Example #22
0
    optimizer = SGD(model.parameters(),
                    lr=0.03,
                    momentum=0.9,
                    weight_decay=5e-4)
    lr_scheduler = LambdaLR(optimizer,
                            lr_lambda=lambda i: 0.5 *
                            (math.cos(i * math.pi / epochs) + 1))
    #c = len(memory_data.classes)
    c = 2

    results = {'train_loss': [], 'test_acc@1': [], 'test_acc@5': []}
    save_name_pre = '{}_{}_{}_{}'.format(feature_dim, k, batch_size, epochs)
    if not os.path.exists('results'):
        os.mkdir('results')
    best_acc = 0.0
    # training loop
    for epoch in range(1, epochs + 1):
        train_loss = train(model, train_loader, optimizer)
        results['train_loss'].append(train_loss)
        lr_scheduler.step()
        #test_acc_1, test_acc_5 = test(model, memory_loader, test_loader)
        #results['test_acc@1'].append(test_acc_1)
        #results['test_acc@5'].append(test_acc_5)
        # save statistics
        #data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        #data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
        #if test_acc_1 > best_acc:
        #best_acc = test_acc_1
    torch.save(model.state_dict(),
               'results/{}_model.pth'.format(save_name_pre))
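The lr_lambda above is a cosine annealing multiplier: it starts at 1.0 and decays smoothly to 0.0 at epoch == epochs. A quick standalone check of the same expression:

import math

epochs = 100
lr_lambda = lambda i: 0.5 * (math.cos(i * math.pi / epochs) + 1)
print(lr_lambda(0), lr_lambda(epochs // 2), lr_lambda(epochs))  # 1.0, ~0.5, 0.0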
Example #23
0
def train(model, tokenizer, train_data, valid_data, args):
    model.train()

    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=RandomSampler(train_dataset),
                                  batch_size=args.train_batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn_bert(
                                      x, tokenizer, args.max_seq_length))

    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(valid_dataset,
                                  sampler=SequentialSampler(valid_dataset),
                                  batch_size=args.eval_batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn_bert(
                                      x, tokenizer, args.max_seq_length))

    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]

    epochs = (args.max_steps - 1) // len(train_dataloader) + 1
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
    #                              betas=eval(args.adam_betas), eps=args.eps,
    #                              weight_decay=args.weight_decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (
        x / args.num_warmup_steps)**-0.5
    scheduler = LambdaLR(optimizer, lr_lambda)

    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        for batch in train_dataloader:

            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            noise_input_ids, clean_input_ids, noise_mask, clean_mask = batch
            #print("noise shape: {}, clean shape: {}".format(noise_input_ids.shape, clean_input_ids.shape))

            outputs = model(noise_input_ids,
                            labels=clean_input_ids,
                            attention_mask=noise_mask)
            loss = outputs[0]
            predict_score = outputs[1]

            bsz = clean_input_ids.size(0)
            items = [loss.data.item(), bsz, clean_mask.sum().item()]
            #print("items: ", items)
            meter.add(*items)

            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()

            if step % args.log_interval == 0:
                lr = scheduler.get_lr()[0]
                loss_sent, loss_token = meter.average()

                logger.info(
                    f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step,
                            scope=locals(),
                            summary=True,
                            train__lr=lr,
                            train__loss_sent=loss_sent,
                            train__token_ppl=math.exp(loss_token))
                meter.init()

            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate_kcBert(
                    model, valid_dataloader, args)
                prediction = correct_kcBert(model,
                                            tokenizer,
                                            valid_noisy,
                                            args,
                                            length_limit=0.1)
                val_em = em(prediction, valid_clean)
                cnt = 0
                for noisy, pred, clean in zip(valid_noisy, prediction,
                                              valid_clean):
                    print(f'[{noisy}], [{pred}], [{clean}]')
                    # print only 10 examples
                    cnt += 1
                    if cnt == 20:
                        break
                # print("len of prediction: {}, len of valid_clean: {}", len(prediction), len(valid_clean))
                val_gleu = gleu(prediction, valid_clean)

                logger.info('-' * 89)
                logger.info(
                    f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}'
                )
                logger.info('-' * 89)
                nsml.report(step=step,
                            scope=locals(),
                            summary=True,
                            valid__loss_sent=val_loss,
                            valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em,
                            valid__gleu=val_gleu)

                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                meter.start += time.time() - start_eval

            if step >= args.max_steps:
                break
        if step >= args.max_steps:
            break
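The lr_lambda used above ramps the multiplier up linearly for num_warmup_steps and then decays it as the inverse square root of the warmed-up step count, similar to the Transformer schedule. A standalone sketch (the warmup length is an assumed value):

import torch
from torch.optim.lr_scheduler import LambdaLR

num_warmup_steps = 1000
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
lr_lambda = lambda x: x / num_warmup_steps if x <= num_warmup_steps \
    else (x / num_warmup_steps) ** -0.5
scheduler = LambdaLR(optimizer, lr_lambda)
# step 0 -> multiplier 0.0, step 1000 -> 1.0, step 4000 -> 0.5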
Example #24
0
class Model:

    def __init__(self, device, num_steps):

        # in and out channels
        # for the generator:
        a, b = 1, 3

        def weights_init(m):
            if isinstance(m, nn.Conv2d):
                init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                init.ones_(m.weight)
                init.zeros_(m.bias)

        G = Generator(a, b).train()
        self.G = G.apply(weights_init).to(device)

        # it turns out that this is important
        init.normal_(self.G.end[0].weight, std=1e-4)

        def lambda_rule(i):
            decay = num_steps // 4
            m = 1.0 if i < decay else 1.0 - (i - decay) / (num_steps - decay)
            return max(m, 1e-3)

        self.optimizer = optim.Adam(self.G.parameters(), lr=2e-4, betas=(0.9, 0.999))
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda_rule)

        self.cp_loss = CPLoss()
        self.gp_loss = GPLoss()

        if USE_FLOAT16:
            self.G, self.optimizer = amp.initialize(self.G, self.optimizer, opt_level='O2')

        # a copy for exponential moving average
        self.G_ema = copy.deepcopy(self.G)

    def train_step(self, A, B):
        """
        The input tensors represent images
        with pixel values in [0, 1] range.

        Arguments:
            A: a float tensor with shape [n, a, h, w].
            B: a float tensor with shape [n, b, h, w].
        Returns:
            a dict with float numbers.
        """
        self.optimizer.zero_grad()

        B_restored = self.G(A)
        # it has shape [n, b, h, w]

        cp_loss = self.cp_loss(B_restored, B)
        gp_loss = self.gp_loss(B_restored, B)
        reconstruction_loss = cp_loss + gp_loss

        if not USE_FLOAT16:
            reconstruction_loss.backward()
        else:
            with amp.scale_loss(reconstruction_loss, self.optimizer) as loss_scaled:
                loss_scaled.backward()

        self.optimizer.step()
        self.scheduler.step()

        # running average of weights
        accumulate(self.G_ema, self.G)

        loss_dict = {
            'total_loss': reconstruction_loss.item(),
            'cp_loss': cp_loss.item(),
            'gp_loss': gp_loss.item()
        }
        return loss_dict

    def save_model(self, model_path):
        torch.save(self.G.state_dict(), model_path + '_generator.pth')
        torch.save(self.G_ema.state_dict(), model_path + '_generator_ema.pth')
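lambda_rule above keeps the multiplier at 1.0 for the first quarter of training, then decays it linearly toward a floor of 1e-3. A quick standalone check of the same rule (num_steps is an assumed value):

num_steps = 1000
decay = num_steps // 4

def lambda_rule(i):
    m = 1.0 if i < decay else 1.0 - (i - decay) / (num_steps - decay)
    return max(m, 1e-3)

print(lambda_rule(0), lambda_rule(decay), lambda_rule(num_steps))  # 1.0, 1.0, 0.001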
Example #25
0
class TriviaQA(pl.LightningModule):
    def __init__(self, args):
        super(TriviaQA, self).__init__()
        self.args = args
        self.hparams = args

        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        self.tokenizer.model_max_length = self.args.max_seq_len
        self.model = self.load_model()
        self.num_labels = 2
        self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size,
                                          self.num_labels)
        self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None

    def load_model(self):
        model = Longformer.from_pretrained(self.args.model_path)
        for layer in model.encoder.layer:
            layer.attention.self.attention_mode = self.args.attention_mode
            self.args.attention_window = layer.attention.self.attention_window

        print("Loaded model with config:")
        print(model.config)

        for p in model.parameters():
            p.requires_grad_(True)
        model.train()
        return model

    def forward(self, input_ids, attention_mask, segment_ids, start_positions,
                end_positions):
        question_end_index = self._get_question_end_index(input_ids)
        # Each batch is one document, and each row of the batch is a chunk of the document.
        # Make sure all rows have the same question length.
        assert (question_end_index[0].float() ==
                question_end_index.float().mean()).item()

        # local attention everywhere
        attention_mask = torch.ones(input_ids.shape,
                                    dtype=torch.long,
                                    device=input_ids.device)
        # global attention for the question tokens
        attention_mask[:, :question_end_index.item()] = 2

        # the sliding_chunks implementation of self-attention requires that seqlen is a multiple of the window size
        input_ids, attention_mask = pad_to_window_size(
            input_ids, attention_mask, self.args.attention_window,
            self.tokenizer.pad_token_id)

        sequence_output = self.model(input_ids,
                                     attention_mask=attention_mask)[0]

        # The pretrained TriviaQA model wasn't trained with padding, so remove padding tokens
        # before computing loss and decoding.
        padding_len = input_ids[0].eq(self.tokenizer.pad_token_id).sum()
        if padding_len > 0:
            sequence_output = sequence_output[:, :-padding_len]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (
            start_logits,
            end_logits,
        )
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds an extra dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            if not self.args.regular_softmax_loss:
                # loss function suggested in section 2.2 here https://arxiv.org/pdf/1710.10723.pdf
                # NOTE: this returns sum of losses, not mean, so loss won't be normalized across different batch sizes
                # but batch size is always 1, so this is not a problem
                start_loss = self.or_softmax_cross_entropy_loss_one_doc(
                    start_logits, start_positions, ignore_index=-1)
                end_loss = self.or_softmax_cross_entropy_loss_one_doc(
                    end_logits, end_positions, ignore_index=-1)
            else:
                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1)
                start_positions = start_positions[:, 0:1]
                end_positions = end_positions[:, 0:1]
                start_loss = loss_fct(start_logits, start_positions[:, 0])
                end_loss = loss_fct(end_logits, end_positions[:, 0])

            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss, ) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)

    def or_softmax_cross_entropy_loss_one_doc(self,
                                              logits,
                                              target,
                                              ignore_index=-1,
                                              dim=-1):
        """loss function suggested in section 2.2 here https://arxiv.org/pdf/1710.10723.pdf"""
        assert logits.ndim == 2
        assert target.ndim == 2
        assert logits.size(0) == target.size(0)

        # with regular CrossEntropyLoss, the numerator is only one of the logits specified by the target
        # here, the numerator is the sum over a few potential targets, some of which are correct answers

        # compute a target mask
        target_mask = target == ignore_index
        # replaces ignore_index with 0, so `gather` will select logit at index 0 for the masked targets
        masked_target = target * (1 - target_mask.long())
        # gather logits
        gathered_logits = logits.gather(dim=dim, index=masked_target)
        # Apply the mask to gathered_logits. Use a mask of -inf because exp(-inf) = 0
        gathered_logits[target_mask] = float('-inf')

        # each batch is one example
        gathered_logits = gathered_logits.view(1, -1)
        logits = logits.view(1, -1)

        # numerator = log(sum(exp(gathered logits)))
        log_score = torch.logsumexp(gathered_logits, dim=dim, keepdim=False)
        # denominator = log(sum(exp(logits)))
        log_norm = torch.logsumexp(logits, dim=dim, keepdim=False)

        # compute the loss
        loss = -(log_score - log_norm)

        # some of the examples might have a loss of `inf` when `target` is all `ignore_index`.
        # remove those from the loss before computing the sum. Use sum instead of mean because
        # it is easier to compute
        return loss[~torch.isinf(loss)].sum()

    def training_step(self, batch, batch_nb):
        input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch
        output = self.forward(input_ids, input_mask, segment_ids,
                              subword_starts, subword_ends)
        loss = output[0]
        lr = loss.new_zeros(
            1) + self.trainer.optimizers[0].param_groups[0]['lr']
        tensorboard_logs = {
            'train_loss': loss,
            'lr': lr,
            'input_size': input_ids.numel(),
            'mem': torch.cuda.memory_allocated(input_ids.device) / 1024**3
        }
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch
        output = self.forward(input_ids, input_mask, segment_ids,
                              subword_starts, subword_ends)
        loss, start_logits, end_logits = output[:3]
        answers = self.decode(input_ids, start_logits, end_logits)

        # each batch is one document
        answers = sorted(answers, key=lambda x: x['score'], reverse=True)[0:1]
        qids = [qids]
        aliases = [aliases]

        f1_scores = [
            evaluation_utils.metric_max_over_ground_truths(
                evaluation_utils.f1_score, answer['text'], aliase_list)
            for answer, aliase_list in zip(answers, aliases)
        ]
        # TODO: if slow, skip em_scores, and use (f1_score == 1.0) instead
        em_scores = [
            evaluation_utils.metric_max_over_ground_truths(
                evaluation_utils.exact_match_score, answer['text'],
                aliase_list) for answer, aliase_list in zip(answers, aliases)
        ]
        answer_scores = [answer['score']
                         for answer in answers]  # start_logit + end_logit
        assert len(answer_scores) == len(f1_scores) == len(em_scores) == len(
            qids) == len(aliases) == 1

        # TODO: delete this metric
        pred_subword_starts = start_logits.argmax(dim=1)
        pred_subword_ends = end_logits.argmax(dim=1)
        exact_match = (subword_ends[:, 0].squeeze(dim=-1) == pred_subword_ends).float() *  \
                      (subword_starts[:, 0].squeeze(dim=-1) == pred_subword_starts).float()

        return {
            'vloss': loss,
            'vem': exact_match.mean(),
            'qids': qids,
            'answer_scores': answer_scores,
            'f1': f1_scores,
            'em': em_scores
        }

    def _get_question_end_index(self, input_ids):
        eos_token_indices = (
            input_ids == self.tokenizer.eos_token_id).nonzero()
        assert eos_token_indices.ndim == 2
        assert eos_token_indices.size(0) == 2 * input_ids.size(0)
        assert eos_token_indices.size(1) == 2
        return eos_token_indices.view(input_ids.size(0), 2, 2)[:, 0, 1]

    def decode(self, input_ids, start_logits, end_logits):
        # find beginning of document
        question_end_index = self._get_question_end_index(input_ids)

        # bsz x seqlen => bsz x n_best_size
        start_logits_indices = start_logits.topk(k=self.args.n_best_size,
                                                 dim=-1).indices
        end_logits_indices = end_logits.topk(k=self.args.n_best_size,
                                             dim=-1).indices

        answers = []
        # This loop can't be vectorized, so loop over each example in the batch separately
        for i in range(start_logits_indices.size(0)):  # bsz
            potential_answers = []
            for start_logit_index in start_logits_indices[i]:  # n_best_size
                for end_logit_index in end_logits_indices[i]:  # n_best_size
                    if start_logit_index <= question_end_index[i]:
                        continue
                    if end_logit_index <= question_end_index[i]:
                        continue
                    if start_logit_index > end_logit_index:
                        continue
                    answer_len = end_logit_index - start_logit_index + 1
                    if answer_len > self.args.max_answer_length:
                        continue
                    potential_answers.append({
                        'start': start_logit_index,
                        'end': end_logit_index,
                        'start_logit': start_logits[i][start_logit_index].item(),
                        'end_logit': end_logits[i][end_logit_index].item()
                    })
            sorted_answers = sorted(potential_answers,
                                    key=lambda x:
                                    (x['start_logit'] + x['end_logit']),
                                    reverse=True)
            if len(sorted_answers) == 0:
                answers.append({'text': 'NoAnswerFound', 'score': -1000000})
            else:
                answer = sorted_answers[0]
                answer_token_ids = input_ids[i,
                                             answer['start']:answer['end'] + 1]
                answer_tokens = self.tokenizer.convert_ids_to_tokens(
                    answer_token_ids.tolist())
                text = self.tokenizer.convert_tokens_to_string(answer_tokens)
                score = answer['start_logit'] + answer['end_logit']
                answers.append({'text': text, 'score': score})
        return answers

    def sync_list_across_gpus(self, l, device, dtype):
        l_tensor = torch.tensor(l, device=device, dtype=dtype)
        gather_l_tensor = [
            torch.ones_like(l_tensor) for _ in range(self.trainer.world_size)
        ]
        torch.distributed.all_gather(gather_l_tensor, l_tensor)
        return torch.cat(gather_l_tensor).tolist()

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['vloss'] for x in outputs]).mean()
        avg_em = torch.stack([x['vem'] for x in outputs]).mean()
        string_qids = [item for sublist in outputs for item in sublist['qids']]
        int_qids = [
            self.val_dataloader_object.dataset.val_qid_string_to_int_map[qid]
            for qid in string_qids
        ]
        answer_scores = [
            item for sublist in outputs for item in sublist['answer_scores']
        ]
        f1_scores = [item for sublist in outputs for item in sublist['f1']]
        em_scores = [item for sublist in outputs for item in sublist['em']]
        print(
            f'before sync --> sizes: {len(int_qids)}, {len(answer_scores)}, {len(f1_scores)}, {len(em_scores)}'
        )
        if self.trainer.use_ddp:
            torch.distributed.all_reduce(avg_loss,
                                         op=torch.distributed.ReduceOp.SUM)
            avg_loss /= self.trainer.world_size
            torch.distributed.all_reduce(avg_em,
                                         op=torch.distributed.ReduceOp.SUM)
            avg_em /= self.trainer.world_size

            int_qids = self.sync_list_across_gpus(int_qids, avg_loss.device,
                                                  torch.int)
            answer_scores = self.sync_list_across_gpus(answer_scores,
                                                       avg_loss.device,
                                                       torch.float)
            f1_scores = self.sync_list_across_gpus(f1_scores, avg_loss.device,
                                                   torch.float)
            em_scores = self.sync_list_across_gpus(em_scores, avg_loss.device,
                                                   torch.int)
        print(
            f'after sync --> sizes: {len(int_qids)}, {len(answer_scores)}, {len(f1_scores)}, {len(em_scores)}'
        )

        # Because there are multiple documents per question, some questions might have multiple corresponding answers
        # Here, we only keep the answer with the highest answer_score
        qa_with_duplicates = defaultdict(list)
        for qid, answer_score, f1_score, em_score in zip(
                int_qids, answer_scores, f1_scores, em_scores):
            qa_with_duplicates[qid].append({
                'answer_score': answer_score,
                'f1': f1_score,
                'em': em_score
            })
        f1_scores = []
        em_scores = []
        for qid, answer_metrics in qa_with_duplicates.items():
            top_answer = sorted(answer_metrics,
                                key=lambda x: x['answer_score'],
                                reverse=True)[0]
            f1_scores.append(top_answer['f1'])
            em_scores.append(top_answer['em'])
        avg_val_f1 = sum(f1_scores) / len(f1_scores)
        avg_val_em = sum(em_scores) / len(em_scores)

        logs = {
            'val_loss': avg_loss,
            'val_em': avg_em,
            'avg_val_f1': avg_val_f1,
            'avg_val_em': avg_val_em
        }

        return {'avg_val_loss': avg_loss, 'log': logs, 'progress_bar': logs}

    def test_step(self, batch, batch_nb):
        input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch
        output = self.forward(input_ids, input_mask, segment_ids,
                              subword_starts, subword_ends)
        loss, start_logits, end_logits = output[:3]
        answers = self.decode(input_ids, start_logits, end_logits)

        # each batch is one document
        answers = sorted(answers, key=lambda x: x['score'], reverse=True)[0:1]
        qids = [qids]
        assert len(answers) == len(qids)
        return {'qids': qids, 'answers': answers}

    def test_end(self, outputs):
        qids = [item for sublist in outputs for item in sublist['qids']]
        answers = [item for sublist in outputs for item in sublist['answers']]

        qa_with_duplicates = defaultdict(list)
        for qid, answer in zip(qids, answers):
            qa_with_duplicates[qid].append({
                'answer_score': answer['score'],
                'answer_text': answer['text'],
            })

        qid_to_answer_text = {}
        for qid, answer_metrics in qa_with_duplicates.items():
            top_answer = sorted(answer_metrics,
                                key=lambda x: x['answer_score'],
                                reverse=True)[0]
            qid_to_answer_text[qid] = top_answer['answer_text']

        with open('predictions.json', 'w') as f:
            json.dump(qid_to_answer_text, f)

        return {'count': len(qid_to_answer_text)}

    def optimizer_step(self,
                       current_epoch,
                       batch_nb,
                       optimizer,
                       optimizer_i,
                       second_order_closure=None):
        optimizer.step()
        optimizer.zero_grad()
        self.scheduler.step(self.global_step)

    def configure_optimizers(self):
        def lr_lambda(current_step):
            if current_step < self.args.warmup:
                return float(current_step) / float(max(1, self.args.warmup))
            return max(
                0.0,
                float(self.args.steps - current_step) /
                float(max(1, self.args.steps - self.args.warmup)))

        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
        self.scheduler = LambdaLR(
            optimizer, lr_lambda, last_epoch=-1
        )  # scheduler is not saved in the checkpoint, but global_step is, which is enough to restart
        self.scheduler.step(self.global_step)

        return optimizer

    @pl.data_loader
    def train_dataloader(self):
        if self.train_dataloader_object is not None:
            return self.train_dataloader_object
        dataset = TriviaQADataset(
            file_path=self.args.train_dataset,
            tokenizer=self.tokenizer,
            max_seq_len=self.args.max_seq_len,
            max_doc_len=self.args.max_doc_len,
            doc_stride=self.args.doc_stride,
            max_num_answers=self.args.max_num_answers,
            max_question_len=self.args.max_question_len,
            ignore_seq_with_no_answers=self.args.ignore_seq_with_no_answers)
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset) if self.trainer.use_ddp else None
        dl = DataLoader(dataset,
                        batch_size=1,
                        shuffle=(sampler is None),
                        num_workers=self.args.num_workers,
                        sampler=sampler,
                        collate_fn=TriviaQADataset.collate_one_doc_and_lists)
        self.train_dataloader_object = dl
        return self.train_dataloader_object

    @pl.data_loader
    def val_dataloader(self):
        if self.val_dataloader_object is not None:
            return self.val_dataloader_object
        dataset = TriviaQADataset(file_path=self.args.dev_dataset,
                                  tokenizer=self.tokenizer,
                                  max_seq_len=self.args.max_seq_len,
                                  max_doc_len=self.args.max_doc_len,
                                  doc_stride=self.args.doc_stride,
                                  max_num_answers=self.args.max_num_answers,
                                  max_question_len=self.args.max_question_len,
                                  ignore_seq_with_no_answers=False
                                  )  # evaluation data should keep all examples
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset) if self.trainer.use_ddp else None
        dl = DataLoader(dataset,
                        batch_size=1,
                        shuffle=(sampler is None),
                        num_workers=self.args.num_workers,
                        sampler=sampler,
                        collate_fn=TriviaQADataset.collate_one_doc_and_lists)
        self.val_dataloader_object = dl
        return self.val_dataloader_object

    @pl.data_loader
    def test_dataloader(self):
        if self.test_dataloader_object is not None:
            return self.test_dataloader_object
        dataset = TriviaQADataset(file_path=self.args.dev_dataset,
                                  tokenizer=self.tokenizer,
                                  max_seq_len=self.args.max_seq_len,
                                  max_doc_len=self.args.max_doc_len,
                                  doc_stride=self.args.doc_stride,
                                  max_num_answers=self.args.max_num_answers,
                                  max_question_len=self.args.max_question_len,
                                  ignore_seq_with_no_answers=False
                                  )  # evaluation data should keep all examples

        dl = DataLoader(dataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=self.args.num_workers,
                        sampler=None,
                        collate_fn=TriviaQADataset.collate_one_doc_and_lists)
        self.test_dataloader_object = dl
        return self.test_dataloader_object

    def configure_ddp(self, model, device_ids):
        model = LightningDistributedDataParallel(model,
                                                 device_ids=device_ids,
                                                 find_unused_parameters=True)
        return model

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        parser.add_argument("--save_dir", type=str, default='triviaqa')
        parser.add_argument("--save_prefix", type=str, required=True)
        parser.add_argument("--train_dataset",
                            type=str,
                            required=False,
                            help="Path to the training squad-format")
        parser.add_argument("--dev_dataset",
                            type=str,
                            required=True,
                            help="Path to the dev squad-format")
        parser.add_argument("--batch_size",
                            type=int,
                            default=8,
                            help="Batch size")
        parser.add_argument(
            "--gpus",
            type=str,
            default='0',
            help=
            "Comma separated list of gpus. Default is gpu 0. To use CPU, use --gpus "
            " ")
        parser.add_argument("--warmup",
                            type=int,
                            default=200,
                            help="Number of warmup steps")
        parser.add_argument("--lr",
                            type=float,
                            default=0.0001,
                            help="Maximum learning rate")
        parser.add_argument(
            "--val_every",
            type=float,
            default=0.2,
            help="Number of training steps between validations")
        parser.add_argument("--val_percent_check",
                            default=1.00,
                            type=float,
                            help='Percent of validation data used')
        parser.add_argument("--num_workers",
                            type=int,
                            default=4,
                            help="Number of data loader workers")
        parser.add_argument("--seed", type=int, default=1234, help="Seed")
        parser.add_argument("--epochs",
                            type=int,
                            default=30,
                            help="Number of epochs")
        parser.add_argument(
            "--max_seq_len",
            type=int,
            default=4096,
            help="Maximum length of seq passed to the transformer model")
        parser.add_argument(
            "--max_doc_len",
            type=int,
            default=4096,
            help="Maximum number of wordpieces of the input document")
        parser.add_argument(
            "--max_num_answers",
            type=int,
            default=64,
            help="Maximum number of answer spans per document (64 => 94%)")
        parser.add_argument("--max_question_len",
                            type=int,
                            default=55,
                            help="Maximum length of the question")
        parser.add_argument(
            "--doc_stride",
            type=int,
            default=-1,
            help=
            "Overlap between document chunks. Use -1 to only use the first chunk"
        )
        parser.add_argument(
            "--ignore_seq_with_no_answers",
            action='store_true',
            help=
            "each example should have at least one answer. Default is False")
        parser.add_argument("--disable_checkpointing",
                            action='store_true',
                            help="No logging or checkpointing")
        parser.add_argument(
            "--n_best_size",
            type=int,
            default=20,
            help="Number of answer candidates. Used at decoding time")
        parser.add_argument(
            "--max_answer_length",
            type=int,
            default=30,
            help="maximum num of wordpieces/answer. Used at decoding time")
        parser.add_argument(
            "--regular_softmax_loss",
            action='store_true',
            help=
            "IF true, use regular softmax. Default is using ORed softmax loss")
        parser.add_argument("--test",
                            action='store_true',
                            help="Test only, no training")
        parser.add_argument("--model_path",
                            type=str,
                            required=True,
                            help="Path to the checkpoint directory")
        parser.add_argument("--no_progress_bar",
                            action='store_true',
                            help="no progress bar. Good for printing")
        parser.add_argument(
            "--attention_mode",
            type=str,
            choices=['tvm', 'sliding_chunks'],
            default='sliding_chunks',
            help='Which implementation of selfattention to use')
        parser.add_argument(
            "--fp32",
            action='store_true',
            help="default is fp16. Use --fp32 to switch to fp32")

        return parser
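In configure_optimizers above, lr_lambda implements a linear warmup over self.args.warmup steps followed by a linear decay to zero at self.args.steps; the scheduler is then stepped once per optimizer step using self.global_step. A standalone sketch of the same schedule (warmup=200 and steps=10000 are assumed values):

import torch
from torch.optim.lr_scheduler import LambdaLR

warmup, steps = 200, 10000

def lr_lambda(current_step):
    if current_step < warmup:
        return float(current_step) / float(max(1, warmup))
    return max(0.0, float(steps - current_step) / float(max(1, steps - warmup)))

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda)  # stepped once per global step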
Example #26
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          jmmd_loss: JointMultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    jmmd_loss.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = jmmd_loss((f_s, F.softmax(y_s, dim=1)),
                                  (f_t, F.softmax(y_t, dim=1)))
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #27
0
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, class_weight_module: AutomaticUpdateClassWeightModule,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.1f')
    non_partial_classes_weights = AverageMeter('Non-partial Weight', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs, tgt_accs, partial_classes_weights, non_partial_classes_weights],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s, class_weight_module.get_class_weight_for_cross_entropy_loss())
        w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(labels_s)
        transfer_loss = domain_adv(f_s, f_t, w_s, w_t)
        class_weight_module.step()
        partial_classes_weight, non_partial_classes_weight = class_weight_module.get_partial_classes_weight()
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        partial_classes_weights.update(partial_classes_weight.item(), x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
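Both training loops above track running statistics with AverageMeter/ProgressMeter helpers. A minimal AverageMeter consistent with how it is called here (update(value, n) plus a formatted string for the progress display) might look like the following; the actual helper may differ in details:

class AverageMeter:
    # Keeps the latest value and a running average, e.g. AverageMeter('Loss', ':6.2f').
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # n is the number of samples the value was computed over
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)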
Example #28
0
def train(model: torch.nn.Module,
          train_dls: List[DataLoader],
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: MultiDatasetClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          model_dir: str = "wandb_local",
          gradient_accumulation: int = 1,
          domain_name: str = ''):
    #best_loss = float('inf')
    best_f1 = 0.0
    patience_counter = 0

    epoch_counter = 0
    total = sum(len(dl) for dl in train_dls)

    # Main loop
    while epoch_counter < n_epochs:
        dl_iters = [iter(dl) for dl in train_dls]
        dl_idx = list(range(len(dl_iters)))
        finished = [0] * len(dl_iters)
        i = 0
        with tqdm(total=total, desc="Training") as pbar:
            while sum(finished) < len(dl_iters):
                random.shuffle(dl_idx)
                for d in dl_idx:
                    domain_dl = dl_iters[d]
                    batches = []
                    try:
                        for j in range(gradient_accumulation):
                            batches.append(next(domain_dl))
                    except StopIteration:
                        finished[d] = 1
                        if len(batches) == 0:
                            continue
                    optimizer.zero_grad()
                    for batch in batches:
                        model.train()
                        batch = tuple(t.to(device) for t in batch)
                        input_ids = batch[0]
                        masks = batch[1]
                        labels = batch[2]
                        # Testing with random domains to see if any effect
                        #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
                        domains = batch[3]

                        loss, logits, alpha = model(input_ids,
                                                    attention_mask=masks,
                                                    domains=domains,
                                                    labels=labels,
                                                    ret_alpha=True)
                        loss = loss.mean() / gradient_accumulation
                        if i % log_interval == 0:
                            # wandb.log({
                            #     "Loss": loss.item(),
                            #     "alpha0": alpha[:,0].cpu(),
                            #     "alpha1": alpha[:, 1].cpu(),
                            #     "alpha2": alpha[:, 2].cpu(),
                            #     "alpha_shared": alpha[:, 3].cpu()
                            # })
                            wandb.log({"Loss": loss.item()})

                        loss.backward()
                        i += 1
                        pbar.update(1)

                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation f1: {F1}")

        #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')

        # Saving the best model and early stopping
        #if val_loss < best_loss:
        if F1 > best_f1:
            best_model = model.state_dict()
            #best_loss = val_loss
            best_f1 = F1
            #wandb.run.summary['best_validation_loss'] = best_loss
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth'
            )
            patience_counter = 0
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss
            })
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
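The train() above accumulates gradients over gradient_accumulation mini-batches before each optimizer step, scaling each loss by 1/gradient_accumulation so the summed gradients match a single larger batch. Stripped of the multi-domain bookkeeping, the pattern is the following self-contained toy sketch (the model, data and accum_steps are stand-ins, not part of the original code):

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loader = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(8)]
accum_steps = 4  # plays the role of gradient_accumulation

optimizer.zero_grad()
for step, (inputs, labels) in enumerate(loader):
    # scale the loss so the accumulated gradients average over accum_steps batches
    loss = nn.functional.cross_entropy(model(inputs), labels) / accum_steps
    loss.backward()                       # gradients add up across backward() calls
    if (step + 1) % accum_steps == 0:
        optimizer.step()                  # one parameter update per accum_steps mini-batches
        optimizer.zero_grad()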
Example #29
0
def main():
    # Training settings and hyper-parameters
    parser = argparse.ArgumentParser(
        description='Data Source (Batch) Prediction for Cell Lines')

    # Dataset parameters ######################################################
    # Pre-processing for dataframes
    parser.add_argument('--rnaseq_scaling',
                        type=str,
                        default='std',
                        help='scaling method for RNA sequence',
                        choices=SCALING_METHODS)

    # Feature usage and partitioning settings
    parser.add_argument('--rnaseq_feature_usage',
                        type=str,
                        default='combat',
                        help='RNA sequence data used',
                        choices=[
                            'source_scale',
                            'combat',
                        ])
    parser.add_argument('--validation_ratio',
                        type=float,
                        default=0.2,
                        help='ratio for validation dataset')

    # Network configuration ###################################################
    parser.add_argument('--layer_dim',
                        type=int,
                        default=256,
                        help='dimension of layers for RNA sequence')
    parser.add_argument('--num_layers',
                        type=int,
                        default=4,
                        help='number of layers for RNA sequence')

    # Training and validation parameters ######################################
    parser.add_argument('--opt',
                        type=str,
                        default='SGD',
                        help='optimizer for data source prediction',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--lr',
                        type=float,
                        default=1e-2,
                        help='learning rate for data source prediction')

    # Starting epoch for validation
    parser.add_argument('--val_start_epoch',
                        type=int,
                        default=0,
                        help='starting epoch for data source prediction')

    # Early stopping based on data source prediction accuracy
    parser.add_argument('--early_stop_patience',
                        type=int,
                        default=50,
                        help='patience for early stopping based on data '
                        'source prediction accuracy')

    # Global/shared training parameters
    parser.add_argument('--l2_regularization',
                        type=float,
                        default=0.,
                        help='L2 regularization for nn weights')
    parser.add_argument('--lr_decay_factor',
                        type=float,
                        default=0.98,
                        help='decay factor for learning rate')
    parser.add_argument('--trn_batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training')
    parser.add_argument('--val_batch_size',
                        type=int,
                        default=256,
                        help='input batch size for validation')
    parser.add_argument('--max_num_batches',
                        type=int,
                        default=10000,
                        help='maximum number of batches per epoch')
    parser.add_argument('--max_num_epochs',
                        type=int,
                        default=1000,
                        help='maximum number of epochs')

    # Miscellaneous settings ##################################################
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--rand_state',
                        type=int,
                        default=0,
                        help='random state of numpy/sklearn/pytorch')

    args = parser.parse_args()
    print('Training Arguments:\n' + json.dumps(vars(args), indent=4))

    # Setting up random seed for reproducible and deterministic results
    seed_random_state(args.rand_state)

    # Computation device config (cuda or cpu)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Data loaders for training/validation ####################################
    dataloader_kwargs = {
        'timeout': 1,
        'shuffle': True,  # boolean, not the string 'True'
        # 'num_workers': multiprocessing.cpu_count() if use_cuda else 0,
        'num_workers': NUM_WORKER if use_cuda else 0,
        'pin_memory': True if use_cuda else False,
    }

    # Drug response dataloaders for training/validation
    cl_clf_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,
        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,
        'rnaseq_scaling': args.rnaseq_scaling,
        'predict_target': 'source',
        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    cl_clf_trn_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=True, **cl_clf_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    cl_clf_val_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=False, **cl_clf_dataset_kwargs),
        batch_size=args.val_batch_size,
        **dataloader_kwargs)

    # Constructing and initializing neural networks ###########################
    net = nn.Sequential()

    prev_dim = cl_clf_trn_loader.dataset.rnaseq_dim
    for label in ['site', 'type', 'category']:
        prev_dim += len(get_label_dict(DATA_ROOT, '%s_dict.txt' % label))

    # net.add_module('dense_%d' % 0, nn.Linear(prev_dim, args.layer_dim))

    for i in range(args.num_layers):
        # net.add_module('residual_block_%d' % i,
        #                ResBlock(layer_dim=args.layer_dim,
        #                         num_layers=2,
        #                         dropout=0.))

        net.add_module('dense_%d' % i, nn.Linear(prev_dim, args.layer_dim))
        net.add_module('dropout_%d' % i, nn.Dropout(0.2))
        prev_dim = args.layer_dim
        net.add_module('relu_%d' % i, nn.ReLU())

    num_data_src = len(get_label_dict(DATA_ROOT, 'data_src_dict.txt'))
    net.add_module('dense', nn.Linear(args.layer_dim, num_data_src))
    net.add_module('logsoftmax', nn.LogSoftmax(dim=1))
    net.apply(basic_weight_init)
    net.to(device)

    print(net)

    # Optimizers, learning rate decay, and miscellaneous ######################
    opt = get_optimizer(opt_type=args.opt,
                        networks=net,
                        learning_rate=args.lr,
                        l2_regularization=args.l2_regularization)
    lr_decay = LambdaLR(optimizer=opt,
                        lr_lambda=lambda e: args.lr_decay_factor**e)

    # Training/validation loops ###############################################
    val_acc = []
    best_acc = 0.
    patience = 0
    start_time = time.time()

    for epoch in range(args.max_num_epochs):

        print('=' * 80 + '\nTraining Epoch %3i:' % (epoch + 1))
        epoch_start_time = time.time()

        # Per-epoch LR decay; passing the epoch index to step() is deprecated in
        # newer PyTorch, where scheduler.step() after the optimizer updates is preferred.
        lr_decay.step(epoch)

        # Training loop #######################################################
        net.train()

        for batch_idx, (rnaseq, data_src, cl_site, cl_type, cl_category) \
                in enumerate(cl_clf_trn_loader):

            if batch_idx >= args.max_num_batches:
                break

            rnaseq, data_src, cl_site, cl_type, cl_category = \
                rnaseq.to(device), data_src.to(device), cl_site.to(device), \
                cl_type.to(device), cl_category.to(device)

            net.zero_grad()

            out_data_src = net(
                torch.cat((rnaseq, cl_site, cl_type, cl_category), dim=1))

            F.nll_loss(input=out_data_src, target=data_src).backward()

            opt.step()

        # Validation loop #####################################################
        net.eval()

        correct_data_src = 0
        with torch.no_grad():
            for rnaseq, data_src, cl_site, cl_type, cl_category \
                    in cl_clf_val_loader:

                rnaseq, data_src, cl_site, cl_type, cl_category = \
                    rnaseq.to(device), data_src.to(device), \
                    cl_site.to(device), cl_type.to(device), \
                    cl_category.to(device)

                out_data_src = net(
                    torch.cat((rnaseq, cl_site, cl_type, cl_category), dim=1))

                pred_data_src = out_data_src.max(1, keepdim=True)[1]

                # print(data_src)
                # print(pred_data_src)

                correct_data_src += pred_data_src.eq(
                    data_src.view_as(pred_data_src)).sum().item()

        data_src_acc = 100. * correct_data_src / len(cl_clf_val_loader.dataset)

        print(
            '\tCell Line Data Source (Batch) Prediction Accuracy: %5.2f%%; ' %
            data_src_acc)

        # Results recording and early stopping
        val_acc.append(data_src_acc)

        if data_src_acc > best_acc:
            patience = 0
            best_acc = data_src_acc
        else:
            patience += 1
        if patience >= args.early_stop_patience:
            print('Validation accuracy does not improve for %d epochs ... '
                  'invoking early stopping.' % patience)
            break

        print('Epoch Running Time: %.1f Seconds.' %
              (time.time() - epoch_start_time))

    print('Program Running Time: %.1f Seconds.' % (time.time() - start_time))
    print('Best Cell Line Data Source (Batch) Prediction Accuracy: %5.2f%%; ' %
          np.amax(val_acc))

    import matplotlib.pyplot as plt
    x = range(1, len(val_acc) + 1)
    plt.plot(x, val_acc)
    plt.xlabel('Epochs')
    plt.ylabel('Validation Accuracy')
    plt.title('Validation Accuracy over Training')
    plt.show()
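In the example above, the network input size is the RNA-seq feature dimension plus the sizes of the site/type/category label dictionaries, which implies those labels reach the network one-hot encoded before being concatenated with the expression features. A minimal sketch of that encoding for a single categorical field (the sizes are hypothetical, and CLClassDataset may already do this internally):

import torch
import torch.nn.functional as F

batch, rnaseq_dim, num_sites = 4, 942, 20        # hypothetical sizes
rnaseq = torch.randn(batch, rnaseq_dim)
site_idx = torch.randint(0, num_sites, (batch,))
cl_site = F.one_hot(site_idx, num_classes=num_sites).float()   # [batch, num_sites]
model_input = torch.cat((rnaseq, cl_site), dim=1)              # mirrors torch.cat((rnaseq, cl_site, ...), dim=1)
print(model_input.shape)                                       # torch.Size([4, 962])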
Example #30
0
def train(model,
          state,
          path,
          annotations,
          val_path,
          val_annotations,
          resize,
          max_size,
          jitter,
          batch_size,
          iterations,
          val_iterations,
          mixed_precision,
          lr,
          warmup,
          milestones,
          gamma,
          is_master=True,
          world=1,
          use_dali=True,
          verbose=True,
          metrics_url=None,
          logdir=None):
    'Train the model on the given dataset'

    # Prepare model
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(),
                    lr=lr,
                    weight_decay=0.0001,
                    momentum=0.9)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level='O2' if mixed_precision else 'O0',
        keep_batchnorm_fp32=True,
        loss_scale=128.0,
        verbosity=is_master)

    if world > 1:
        model = DistributedDataParallel(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path,
        jitter,
        max_size,
        batch_size,
        stride,
        world,
        annotations,
        training=True)
    if verbose: print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'gpu' if world == 1 else 'gpus'))
        print('    batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(logdir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):

            # Forward pass
            profiler.start('fw')

            optimizer.zero_grad()
            cls_loss, box_loss = model([data, target])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            scheduler.step(iteration)

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            if is_master and (profiler.totals['train'] > 60
                              or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration,
                                                 iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'],
                                                       batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size /
                                                  profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if logdir is not None:
                    writer.add_scalar('Train/Loss/Focal', focal_loss,
                                      iteration)
                    writer.add_scalar('Train/Loss/Box', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate,
                                      iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(
                        metrics_url, {
                            'focal loss': mean(cls_losses),
                            'box loss': mean(box_losses),
                            'im_s': batch_size / profiler.means['train'],
                            'lr': learning_rate
                        })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            if val_annotations and (iteration == iterations
                                    or iteration % val_iterations == 0):
                infer(model,
                      val_path,
                      None,
                      resize,
                      max_size,
                      batch_size,
                      annotations=val_annotations,
                      mixed_precision=mixed_precision,
                      is_master=is_master,
                      world=world,
                      use_dali=use_dali,
                      is_validation=True,
                      verbose=False,
                      logdir=logdir,
                      iteration=iteration)
                model.train()

            if iteration == iterations:
                break

    if logdir is not None:
        writer.close()
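The schedule() closure above ramps the LR multiplier linearly from 0.1 to 1.0 over the warmup iterations and then multiplies it by gamma at each milestone. A quick standalone check with assumed values (warmup=1000, milestones=[30000, 40000], gamma=0.1):

warmup, milestones, gamma = 1000, [30000, 40000], 0.1   # assumed values for the check

def schedule(train_iter):
    if warmup and train_iter <= warmup:
        return 0.9 * train_iter / warmup + 0.1
    return gamma ** len([m for m in milestones if m <= train_iter])

print(schedule(0), schedule(500), schedule(1000))        # 0.1, 0.55, 1.0 during warmup
print(schedule(20000), schedule(35000), schedule(45000)) # 1.0, 0.1, ~0.01 after each milestone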
Example #31
0
class Learner(object):
    def __init__(self, model,
                 ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
                 alpha=0.99, epsilon=1e-5, number_updates=int(1e6), lrschedule='linear',
                 use_actor_critic=False, rew_loss_coef=0.0, st_loss_coef=0.0,
                 subtree_loss_coef=0.0,
                 nsteps=5, nenvs=1,
                 tree_depth=0):
        self.max_grad_norm = max_grad_norm
        self.use_actor_critic = use_actor_critic
        self.use_reward_loss = model.predict_rewards and rew_loss_coef > 0
        self.rew_loss_coef = rew_loss_coef
        self.use_st_loss = st_loss_coef > 0 and tree_depth > 0
        self.st_loss_coef = st_loss_coef
        self.subtree_loss_coef = subtree_loss_coef
        self.use_subtree_loss = subtree_loss_coef > 0
        self.model = model
        self.nsteps = nsteps
        self.nenvs = nenvs
        self.batch_size = nsteps * nenvs
        self.num_actions = model.num_actions
        self.tree_depth = tree_depth

        if USE_CUDA:
            self.model = self.model.cuda()

        if not self.use_actor_critic:
            self.target_model = copy.deepcopy(self.model)

            if USE_CUDA:
                self.target_model = self.target_model.cuda()

        self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=lr, alpha=alpha, eps=epsilon)

        if lrschedule == "linear":
            self.scheduler = LambdaLR(self.optimizer, lambda step: 1.0 - (step / number_updates))
        elif lrschedule == "constant":
            self.scheduler = LambdaLR(self.optimizer, lambda step: 1.0)
        else:
            raise ValueError("lrschedule should be 'linear' or 'constant'")

        self.step = self.model.step
        if self.use_actor_critic:
            self.value = self.model.value
            self.ent_coef = ent_coef
            self.vf_coef = vf_coef
        else:
            self.value = self.target_model.value

    def train(self, obs, next_obs, returns, rewards, masks, actions, values):
        """
        :param obs: [batch_size x height x width x channels] observations in NHWC
        :param next_obs: [batch_size x height x width x channels] one-step next states
        :param returns: [batch_size] n-step discounted returns with bootstrapped value
        :param rewards: [batch_size] 1-step rewards
        :param masks: [batch_size] boolean episode termination mask
        :param actions: [batch_size] actions taken
        :param values: [batch_size] predicted state values
        """

        # compute the sequences we need to get back reward predictions
        action_sequences, reward_sequences, sequence_mask = build_sequences(
            [torch.from_numpy(actions), torch.from_numpy(rewards)], self.nenvs, self.nsteps, self.tree_depth, return_mask=True)
        action_sequences = cudify(action_sequences.long().squeeze(-1))
        reward_sequences = make_variable(reward_sequences.squeeze(-1))
        sequence_mask = make_variable(sequence_mask.squeeze(-1))

        Q, V, tree_result = self.model(obs)

        actions = make_variable(torch.from_numpy(actions).long(), requires_grad=False)
        returns = make_variable(torch.from_numpy(returns), requires_grad=False)

        policy_loss, value_loss, reward_loss, state_loss, subtree_loss_np, policy_entropy = 0, 0, 0, 0, 0, 0
        if self.use_actor_critic:
            values = make_variable(torch.from_numpy(values), requires_grad=False)
            advantages = returns - values
            probs = F.softmax(Q, dim=-1)
            log_probs = F.log_softmax(Q, dim=-1)
            log_probs_taken = log_probs.gather(1, actions.unsqueeze(1)).squeeze()
            pg_loss = -torch.mean(log_probs_taken * advantages.squeeze())
            vf_loss = F.mse_loss(V, returns)
            entropy = -torch.mean(torch.sum(probs * log_probs, 1))
            loss = pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy

            policy_loss = pg_loss.data.cpu().numpy()
            value_loss = vf_loss.data.cpu().numpy()
            policy_entropy = entropy.data.cpu().numpy()
        else:
            Q_taken = Q.gather(1, actions.unsqueeze(1)).squeeze()
            bellman_loss = F.mse_loss(Q_taken, returns)
            loss = bellman_loss
            value_loss = bellman_loss.data.cpu().numpy()

        if self.use_reward_loss:
            r_taken = get_paths(tree_result["rewards"], action_sequences, self.batch_size, self.num_actions)
            rew_loss = F.mse_loss(torch.cat(r_taken, 1), reward_sequences, reduction='none')
            rew_loss = torch.sum(rew_loss * sequence_mask) / sequence_mask.sum()
            loss = loss + rew_loss * self.rew_loss_coef
            reward_loss = rew_loss.data.cpu().numpy()

        if self.use_st_loss:
            st_embeddings = tree_result["embeddings"][0]
            st_targets, st_mask = build_sequences([st_embeddings.data], self.nenvs, self.nsteps, self.tree_depth, return_mask=True, offset=1)
            st_targets = make_variable(st_targets.view(self.batch_size, -1))
            st_mask = make_variable(st_mask.view(self.batch_size, -1))

            st_taken = get_paths(tree_result["embeddings"][1:], action_sequences, self.batch_size, self.num_actions)

            st_taken_cat = torch.cat(st_taken, 1)

            st_loss = F.mse_loss(st_taken_cat, st_targets, reduction='none')
            st_loss = torch.sum(st_loss * st_mask) / st_mask.sum()

            state_loss = st_loss.data.cpu().numpy()
            loss = loss + st_loss * self.st_loss_coef

        if self.use_subtree_loss:
            subtree_taken = get_subtree(tree_result["values"], action_sequences, self.batch_size, self.num_actions)
            target_subtrees = tree_result["values"][0:-1]
            subtree_taken_clip = time_shift_tree(subtree_taken, self.nenvs, self.nsteps, -1)
            target_subtrees_clip = time_shift_tree(target_subtrees, self.nenvs, self.nsteps, 1)

            subtree_loss = [(s_taken - s_target).pow(2).mean() for (s_taken, s_target) in zip(subtree_taken_clip, target_subtrees_clip)]
            subtree_loss = sum(subtree_loss)
            subtree_loss_np = subtree_loss.data.cpu().numpy()

            loss = loss + subtree_loss * self.subtree_loss_coef

        # Note: newer PyTorch expects scheduler.step() to follow optimizer.step();
        # the original ordering is preserved here.
        self.scheduler.step()
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        return policy_loss, value_loss, reward_loss, state_loss, subtree_loss_np, policy_entropy, grad_norm
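The 'linear' option above anneals the learning rate from lr to zero over number_updates scheduler steps. A standalone check with assumed values (lr=7e-4, number_updates=1000):

import torch
from torch.optim.lr_scheduler import LambdaLR

number_updates = 1000                                     # assumed value for the check
p = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.RMSprop([p], lr=7e-4)
scheduler = LambdaLR(optimizer, lambda step: 1.0 - (step / number_updates))

for _ in range(500):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]['lr'])                    # ~3.5e-4, half the base LR halfway through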