def train(self):
        """Run the full training loop with periodic validation, then generate
        test-set predictions.

        Resumes from ``self.done_epochs``: epoch indices continue where a
        previous run stopped, so the logged epoch numbers, the ``eval_every``
        schedule and the learning-rate schedule stay aligned after a restart.
        Per-step and per-epoch metrics go to wandb and ``self.logger``; the
        model is checkpointed whenever validation loss improves.
        """
        # BUG FIX: previously this was `range(self.config["epochs"] -
        # self.done_epochs)`, which restarts `epoch` at 0 on resume and shifts
        # the displayed epoch number, the eval_every cadence and the LR
        # schedule. Starting the range at self.done_epochs runs the same
        # number of iterations with correct indices (mirrors the sibling
        # trainer elsewhere in this file).
        for epoch in range(self.done_epochs, self.config["epochs"]):
            self.logger.record(f'Epoch {epoch+1}/{self.config["epochs"]}',
                               mode='train')
            train_meter = common.AverageMeter()

            for idx, batch in enumerate(self.train_loader):
                train_metrics = self.train_one_step(batch)
                train_meter.add(train_metrics)
                # Log the running (epoch-average) loss every step.
                wandb.log({"Train loss": train_meter.return_metrics()["loss"]})
                common.progress_bar(progress=(idx + 1) /
                                    len(self.train_loader),
                                    status=train_meter.return_msg())

            common.progress_bar(progress=1.0, status=train_meter.return_msg())
            self.logger.write(train_meter.return_msg(), mode='train')
            self.adjust_learning_rate(epoch + 1)

            # Fetch the aggregated metrics once instead of once per key.
            epoch_metrics = train_meter.return_metrics()
            wandb.log({
                "Train accuracy": epoch_metrics["accuracy"],
                "Train F1 score": epoch_metrics["f1"],
                "Epoch": epoch + 1
            })

            if (epoch + 1) % self.config["eval_every"] == 0:
                self.logger.record(f'Epoch {epoch+1}/{self.config["epochs"]}',
                                   mode='val')
                val_meter = common.AverageMeter()
                for idx, batch in enumerate(self.val_loader):
                    val_metrics = self.validate_one_step(batch)
                    val_meter.add(val_metrics)
                    common.progress_bar(progress=(idx + 1) /
                                        len(self.val_loader),
                                        status=val_meter.return_msg())

                common.progress_bar(progress=1.0,
                                    status=val_meter.return_msg())
                self.logger.write(val_meter.return_msg(), mode='val')

                # Single metrics fetch for logging and the checkpoint check.
                val_results = val_meter.return_metrics()
                wandb.log({
                    "Validation loss": val_results["loss"],
                    "Validation accuracy": val_results["accuracy"],
                    "Validation F1 score": val_results["f1"],
                    "Epoch": epoch + 1
                })

                # Checkpoint only when validation loss improves.
                if val_results["loss"] < self.best_val_loss:
                    self.best_val_loss = val_results["loss"]
                    self.save_model()

        print("\n\n")
        self.logger.record("Finished training! Generating test predictions...",
                           mode='info')
        self.get_test_predictions(self.config.get("polarize_predictions",
                                                  True))
# Beispiel #2
# 0
def batch_loop(net,
               nc,
               optimizers,
               train_loader,
               working_device,
               criterion,
               amp_flag,
               gamma,
               gamma_rate,
               temp_func=None,
               gamma_target=0.0,
               epoch=0):
    """Run one training epoch over ``train_loader`` and return epoch metrics.

    :param net: Input network module
    :param nc: Network Config class
    :param optimizers:  a list of optimizers
    :param train_loader: The training dataset loader
    :param working_device: the device the training step runs on
    :param criterion: The loss criterion
    :param amp_flag: whether automatic mixed precision (AMP) is enabled
    :param gamma: weight of the weights-compression loss term
    :param gamma_rate: power factor for the weights-compression term
        (passed through to ``batch_step``)
    :param temp_func: The gumbel softmax temperature function
    :param gamma_target: the target weights compression
    :param epoch: Epoch index
    :return: tuple of (average loss, accuracy in percent, last gumbel
        temperature, or None if ``temp_func`` was not supplied)
    """
    n = len(train_loader)
    i = 0  # batch index within this epoch
    t = None
    loss_meter = common.AverageMeter()
    accuracy_meter = common.AverageMeter()

    with tqdm(total=n) as pbar:
        prefetcher = DataPreFetcher(train_loader)
        image, label = prefetcher.next()
        while image is not None:  # prefetcher yields None once exhausted
            if temp_func is not None:
                # Anneal the gumbel-softmax temperature by global step count.
                t = temp_func(epoch * n + i)
                nc.set_temperature(t)
            correct, total, loss_value = batch_step(pbar, net, image,
                                                    optimizers, label,
                                                    criterion, gamma,
                                                    gamma_target, gamma_rate,
                                                    amp_flag, working_device)
            loss_meter.update(loss_value)
            # Weight the accuracy update by the batch size reported back.
            accuracy_meter.update(correct, n=total)
            i += 1
            image, label = prefetcher.next()
    # Ensure all queued GPU work finishes before metrics are returned.
    torch.cuda.synchronize()
    return loss_meter.avg, 100 * accuracy_meter.avg, t
# Beispiel #3
# 0
def train(bg_loader, fg_loader, mask_loader, vgg16_net, deblend_net, Inp_net,
          Dis_net, deblend_optimizer, Inp_optimizer, Dis_optimizer, epoch):
    """Run one adversarial training epoch for the deblend/inpaint pipeline.

    Per batch, three networks are updated in order:
      1. ``deblend_net`` — separates a synthetic blend ``x`` into two layers
         (``y1``, ``y2``) plus a residual map ``Rmap``.
      2. ``Inp_net`` — inpaints the masked residual region of ``y1``.
      3. ``Dis_net`` — discriminator providing the adversarial signal.

    :param bg_loader: iterable yielding background image batches
    :param fg_loader: iterable yielding foreground image batches
    :param mask_loader: tensor of blending masks, indexed along dim 0
    :param vgg16_net: fixed VGG16 used for perceptual loss
    :param deblend_net: layer-separation network (trained here)
    :param Inp_net: inpainting network (trained here)
    :param Dis_net: discriminator (trained here)
    :param deblend_optimizer: optimizer for ``deblend_net``
    :param Inp_optimizer: optimizer for ``Inp_net``
    :param Dis_optimizer: optimizer for ``Dis_net``
    :param epoch: epoch index, used for the curriculum schedule below
    :return: None (losses are accumulated in local meters and printed)

    NOTE(review): this code uses pre-0.4 PyTorch idioms
    (``autograd.Variable(...)`` and ``loss.data[0]``); on PyTorch >= 0.4
    the latter must become ``loss.item()``. Left unchanged to match the
    environment this snippet targets — confirm before porting.
    """
    B_loss = common.AverageMeter()
    I_loss = common.AverageMeter()
    D_loss = common.AverageMeter()

    deblend_net.train()
    Inp_net.train()
    Dis_net.train()

    # real means original
    # fake means deblended
    real_label = torch.ones(args['batch_size'])
    fake_label = torch.zeros(args['batch_size'])

    real_label = autograd.Variable(real_label).cuda()
    fake_label = autograd.Variable(fake_label).cuda()

    critic_criterion = nn.BCELoss().cuda()
    mse_criterion = nn.MSELoss().cuda()

    bg_iter = iter(bg_loader)
    fg_iter = iter(fg_loader)

    batch_count = 0
    total_count = args['num_training']
    while batch_count < total_count:

        batch_count += 1

        # Randomly swap which stream supplies the foreground vs background.
        if random.random() > args['fgbg_swap_rate']:
            x3 = next(bg_iter)
            x4 = next(fg_iter)
        else:
            x3 = next(fg_iter)
            x4 = next(bg_iter)

        # A short last batch from either loader ends the epoch early.
        if x3.size(0) != x4.size(0):
            break
        #############################################################################
        #
        # image systhesis
        #
        #############################################################################
        # Build a per-sample blending mask as the union (clamped sum) of two
        # randomly chosen masks from mask_loader.
        mask = torch.zeros(args['batch_size'], 3, args['input_resolution'],
                           args['input_resolution'])
        for b in range(0, args['batch_size']):
            mask_num = mask_loader.size()[0]
            mask_id_1 = random.randint(0, mask_num - 1)
            mask_id_2 = random.randint(0, mask_num - 1)
            mask[b] = torch.clamp(mask_loader[mask_id_1] +
                                  mask_loader[mask_id_2],
                                  min=0,
                                  max=1.0)
        x, x1, x2 = blend.blend_gauss_part(x3, x4, mask)

        # Curriculum: over the first 20 epochs, edge hints degrade from
        # easy (full scale, low threshold) toward the hardest setting.
        if epoch < 20:
            scale_lower = 1.0 - 0.5 * (epoch / 20.0)
            th_upper = 0.5 + 0.5 * (epoch / 20.0)
        else:
            scale_lower = 0.5
            th_upper = 1.0

        scale_ratio = random.uniform(scale_lower, 1.0)
        th_ratio = random.uniform(0.5, th_upper)

        # Randomly give the network an edge hint for only one of the two
        # layers (first two branches) or partial hints for both (else).
        partial_ratio = random.random()
        # print(partial_ratio, args['partial_swap_ratio'])
        if partial_ratio < args['partial_swap_ratio']:
            edge_x1 = torch.ones(x.size(0), 1, x.size(2), x.size(3))
            edge_x2 = blend.H_map(x2, e_type='canny', th_ratio=th_ratio)
            edge_x2 = blend.Partial_Map(edge_x2,
                                        size_ratio=1.0,
                                        scale_ratio=scale_ratio)
        elif partial_ratio > 2.0 * args['partial_swap_ratio']:
            edge_x1 = blend.H_map(x1, e_type='canny', th_ratio=th_ratio)
            edge_x1 = blend.Partial_Map(edge_x1,
                                        size_ratio=1.0,
                                        scale_ratio=scale_ratio)
            edge_x2 = torch.ones(x.size(0), 1, x.size(2), x.size(3))
        else:
            edge_x1 = blend.H_map(x1, e_type='canny', th_ratio=th_ratio)
            edge_x1 = blend.Partial_Map(edge_x1,
                                        size_ratio=1.0,
                                        scale_ratio=scale_ratio)
            edge_x2 = blend.H_map(x2, e_type='canny', th_ratio=th_ratio)
            edge_x2 = blend.Partial_Map(edge_x2,
                                        size_ratio=1.0,
                                        scale_ratio=scale_ratio)
        edge_x1, edge_x2 = blend.Edge_Sparse(edge_x1, edge_x2, mask)

        x1, x2, x3, x4, x, mask, edge_x1, edge_x2 = [
            autograd.Variable(z).cuda()
            for z in (x1, x2, x3, x4, x, mask, edge_x1, edge_x2)
        ]

        #############################################################################
        #
        # train deblend_net
        #
        #############################################################################
        deblend_optimizer.zero_grad()

        y1, y2, Rmap = deblend_net(x, edge_x1, edge_x2)

        # train with vgg16
        vgg16_loss_2 = calc_vgg16_loss(vgg16_net, mse_criterion, x2, y2)
        vgg16_loss_3 = calc_vgg16_loss(vgg16_net, mse_criterion, x1, y1)
        error_vgg16 = args['weight_vgg16'] * (vgg16_loss_2 + vgg16_loss_3)

        # train with pixel
        pixel_loss_2 = mse_criterion(y2, x2)
        pixel_loss_3 = mse_criterion(y1, x1)
        pixel_loss_4 = mse_criterion(Rmap, mask.mean(1))
        error_pixel = args['weight_recon'] * (pixel_loss_2 + pixel_loss_3 +
                                              pixel_loss_4)

        deblend_loss = error_vgg16 + error_pixel
        # retain_graph=True: y1/y2/Rmap (not detached) feed the inpainting
        # loss below, which backpropagates through this graph again.
        deblend_loss.backward(retain_graph=True)

        deblend_optimizer.step()

        #############################################################################
        #
        # train inp_net
        #
        #############################################################################
        Inp_optimizer.zero_grad()

        # y1, y2 = y1.detach(), y2.detach()

        gray_y2 = blend.ToGray(y2, True)
        y1_1 = torch.cat((y1, gray_y2), 1)

        y3_f = Inp_net(y1_1)
        # Restrict the inpainted content to the predicted residual region.
        ap_mask = blend.ToGray(Rmap.expand_as(y3_f), True)
        y3_f = y3_f * ap_mask
        y3 = y3_f + y1

        # train with critic
        # Generator-side adversarial term: push fake predictions toward the
        # real label (non-saturating formulation).
        fake_pred = Dis_net(y3_f)
        error_dis = args['weight_iadv'] * critic_criterion(
            fake_pred, real_label)

        # train with vgg16
        vgg16_loss_inp = calc_vgg16_loss(vgg16_net, mse_criterion, x3, y3)
        error_vgg16_inp = args['weight_ipec'] * vgg16_loss_inp

        # train with pixel
        pixel_loss_inp = mse_criterion(y3, x3)
        error_pixel_inp = args['weight_ipix'] * pixel_loss_inp

        inp_loss = error_dis + error_vgg16_inp + error_pixel_inp
        inp_loss.backward()

        Inp_optimizer.step()

        #############################################################################
        #
        # train Dis_net refinement
        #
        #############################################################################
        Dis_optimizer.zero_grad()

        # train with real
        # The "real" sample for the critic is the ground-truth residual.
        x3_f = x3 - y1
        real_pred = Dis_net(x3_f.detach())  # t
        dis_real = critic_criterion(real_pred, real_label)

        # train with fake
        # .detach() blocks discriminator gradients from reaching Inp_net.
        fake_pred = Dis_net(y3_f.detach())
        dis_fake = critic_criterion(fake_pred, fake_label)

        dis_loss = dis_real + dis_fake
        dis_loss.backward()

        Dis_optimizer.step()

        #############################################################################
        #
        # update loss
        #
        #############################################################################
        B_loss.update(deblend_loss.data[0], x1.size(0))
        I_loss.update(inp_loss.data[0], x1.size(0))
        D_loss.update(dis_loss.data[0], x1.size(0))

        if batch_count % args['print_frequency'] == 1:
            print(
                '    epoch: {}   train: {}/{}\t'
                '    deblend_loss: {deblend_loss.avg:.4f}\t'
                '    inp_loss: {inp_loss.avg:.4f}   dis_loss: {dis_loss.avg:.4f}'
                .format(epoch,
                        batch_count,
                        total_count,
                        deblend_loss=B_loss,
                        inp_loss=I_loss,
                        dis_loss=D_loss))
    def train(self, run_id=0):
        """Train for the configured number of epochs, validating periodically.

        Args:
            run_id: identifier forwarded to the post-training breakdown and
                autoregressive-forecast steps.

        Returns:
            Tuple ``(best_val, best_mape)`` — the best validation loss and
            best validation MAPE observed across the run.
        """
        print()
        for epoch in range(self.done_epochs, self.config['epochs'] + 1):
            epoch_train_meter = common.AverageMeter()
            epoch_val_meter = common.AverageMeter()
            # Ensure both datasets live on the target device.
            self.train_data = self.train_data.to(self.device)
            self.val_data = self.val_data.to(self.device)
            self.logger.record('Epoch [{:3d}/{}]'.format(
                epoch, self.config['epochs']),
                               mode='train')
            if self.scheduler is not None:
                self.adjust_learning_rate(epoch + 1)

            num_batches = len(self.train_loader)
            for step in range(num_batches):
                # The loader hands out batches via flow() rather than
                # iteration, hence the explicit index loop.
                step_metrics = self.train_on_batch(self.train_loader.flow())
                wandb.log({
                    'Loss': step_metrics['Loss'],
                    'MAPE': step_metrics['MAPE'],
                    'Epoch': epoch
                })
                epoch_train_meter.add(step_metrics)
                common.progress_bar(progress=step / num_batches,
                                    status=epoch_train_meter.return_msg())

            common.progress_bar(progress=1,
                                status=epoch_train_meter.return_msg())
            self.logger.write(epoch_train_meter.return_msg(), mode='train')
            wandb.log({
                'Learning rate': self.optim.param_groups[0]['lr'],
                'Epoch': epoch
            })

            # Save state
            self.save_state(epoch)

            # Validation
            if epoch % self.config['eval_every'] == 0:
                self.logger.record('Epoch [{:3d}/{}]'.format(
                    epoch, self.config['epochs']),
                                   mode='val')
                raw_val_metrics, forecast = self.validate()
                self.compare_predictions(forecast)
                epoch_val_meter.add(raw_val_metrics)

                self.logger.record(epoch_val_meter.return_msg(), mode='val')
                averaged = epoch_val_meter.return_metrics()
                wandb.log({
                    'Validation loss': averaged['Loss'],
                    'Validation MAPE': averaged['MAPE'],
                    'Epoch': epoch
                })

                # Track the best checkpoints seen so far.
                if averaged['Loss'] < self.best_val:
                    self.best_val = averaged['Loss']
                    self.save_model()

                if averaged['MAPE'] < self.best_mape:
                    self.best_mape = averaged['MAPE']

        self.stack_breakdown(run_id)
        self.autoregressive_forecast(run_id)
        print('\n\n[INFO] Training complete!')
        return self.best_val, self.best_mape