Example #1
def train_model(model,
                model_test,
                criterion,
                optimizer,
                scheduler,
                num_epochs=25):
    since = time.time()

    #best_model_wts = model.state_dict()
    #best_acc = 0.0
    warm_up = 0.1  # start from 0.1 * the base learning rate
    warm_iteration = round(dataset_sizes['satellite'] /
                           opt.batchsize) * opt.warm_epoch  # warm up over the first opt.warm_epoch epochs

    for epoch in range(num_epochs - start_epoch):
        epoch = epoch + start_epoch
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            running_corrects2 = 0.0
            running_corrects3 = 0.0
            # Iterate over data.
            for data, data2, data3, data4 in zip(dataloaders['satellite'],
                                                 dataloaders['street'],
                                                 dataloaders['drone'],
                                                 dataloaders['google']):
                # get the inputs
                inputs, labels = data
                inputs2, labels2 = data2
                inputs3, labels3 = data3
                inputs4, labels4 = data4
                now_batch_size, c, h, w = inputs.shape
                if now_batch_size < opt.batchsize:  # skip the last batch
                    continue
                if use_gpu:
                    inputs = Variable(inputs.cuda().detach())
                    inputs2 = Variable(inputs2.cuda().detach())
                    inputs3 = Variable(inputs3.cuda().detach())
                    labels = Variable(labels.cuda().detach())
                    labels2 = Variable(labels2.cuda().detach())
                    labels3 = Variable(labels3.cuda().detach())
                    if opt.extra_Google:
                        inputs4 = Variable(inputs4.cuda().detach())
                        labels4 = Variable(labels4.cuda().detach())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                if phase == 'val':
                    with torch.no_grad():
                        outputs, outputs2 = model(inputs, inputs2)
                else:
                    if opt.views == 2:
                        outputs, outputs2 = model(inputs, inputs2)
                    elif opt.views == 3:
                        if opt.extra_Google:
                            outputs, outputs2, outputs3, outputs4 = model(
                                inputs, inputs2, inputs3, inputs4)
                        else:
                            outputs, outputs2, outputs3 = model(
                                inputs, inputs2, inputs3)
                _, preds = torch.max(outputs.data, 1)
                _, preds2 = torch.max(outputs2.data, 1)

                if opt.views == 2:
                    loss = criterion(outputs, labels) + criterion(
                        outputs2, labels2)
                elif opt.views == 3:
                    _, preds3 = torch.max(outputs3.data, 1)
                    loss = criterion(outputs, labels) + criterion(
                        outputs2, labels2) + criterion(outputs3, labels3)
                    if opt.extra_Google:
                        loss += criterion(outputs4, labels4)
                # backward + optimize only if in training phase
                if epoch < opt.warm_epoch and phase == 'train':
                    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                    loss *= warm_up

                if phase == 'train':
                    if fp16:  # use the amp optimizer wrapper to backpropagate the scaled loss
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()
                    ##########
                    if opt.moving_avg < 1.0:
                        update_average(model_test, model, opt.moving_avg)

                # statistics
                if int(version[0]) > 0 or int(version[2]) > 3:
                    # new-style API (PyTorch >= 0.4): loss is a 0-dim tensor
                    running_loss += loss.item() * now_batch_size
                else:
                    # old-style API (PyTorch 0.3.x)
                    running_loss += loss.data[0] * now_batch_size
                running_corrects += float(torch.sum(preds == labels.data))
                running_corrects2 += float(torch.sum(preds2 == labels2.data))
                if opt.views == 3:
                    running_corrects3 += float(
                        torch.sum(preds3 == labels3.data))

            epoch_loss = running_loss / dataset_sizes['satellite']
            epoch_acc = running_corrects / dataset_sizes['satellite']
            epoch_acc2 = running_corrects2 / dataset_sizes['satellite']

            if opt.views == 2:
                print(
                    '{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f}'
                    .format(phase, epoch_loss, epoch_acc, epoch_acc2))
            elif opt.views == 3:
                epoch_acc3 = running_corrects3 / dataset_sizes['satellite']
                print(
                    '{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f} Drone_Acc: {:.4f}'
                    .format(phase, epoch_loss, epoch_acc, epoch_acc2,
                            epoch_acc3))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)
            # deep copy the model
            if phase == 'train':
                scheduler.step()
            last_model_wts = model.state_dict()
            if epoch % 20 == 19:
                save_network(model, opt.name, epoch)
            #draw_curve(epoch)

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    #print('Best val Acc: {:4f}'.format(best_acc))
    #save_network(model_test, opt.name+'adapt', epoch)

    return model
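
The `opt.moving_avg` branch above keeps `model_test` as a shadow copy of `model` via an `update_average` helper that is not part of this listing. A minimal sketch consistent with the call `update_average(model_test, model, opt.moving_avg)`; the exponential-moving-average convention and the signature are assumptions:

import torch

def update_average(model_tgt, model_src, alpha):
    # Assumed EMA convention: target <- alpha * target + (1 - alpha) * source.
    # With alpha = 0 the target becomes an exact copy of the source.
    with torch.no_grad():
        for p_tgt, p_src in zip(model_tgt.parameters(), model_src.parameters()):
            p_tgt.data.copy_(alpha * p_tgt.data + (1 - alpha) * p_src.data)
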
Example #2
    def update_log(self, results):
        """
        Updates the internal log, Algorithm.log_dict
        Args:
            - results (dictionary)
        """
        results = self.sanitize_dict(results, to_out_device=False)
        # check all the fields exist
        for field in self.logged_fields:
            assert field in results, f"field {field} missing"
        # compute statistics for the current batch
        batch_log = {}
        with torch.no_grad():
            for m in self.logged_metrics:
                if not self.no_group_logging:
                    group_metrics, group_counts, worst_group_metric = m.compute_group_wise(
                        results['y_pred'],
                        results['y_true'],
                        results['g'],
                        self.grouper.n_groups,
                        return_dict=False)
                    batch_log[f'{self.group_prefix}{m.name}'] = group_metrics
                batch_log[m.agg_metric_field] = m.compute(
                    results['y_pred'], results['y_true'],
                    return_dict=False).item()
            count = results['y_true'].numel()

        # transfer other statistics in the results dictionary
        for field in self.logged_fields:
            if field.startswith(self.group_prefix) and self.no_group_logging:
                continue
            v = results[field]
            if isinstance(v, torch.Tensor) and v.numel() == 1:
                batch_log[field] = v.item()
            else:
                if isinstance(v, torch.Tensor):
                    assert v.numel() == self.grouper.n_groups, \
                        "Current implementation deals only with group-wise statistics or a single-number statistic"
                    assert field.startswith(self.group_prefix)
                batch_log[field] = v

        # update the log dict with the current batch
        if not self._has_log:  # since it is the first log entry, just save the current log
            self.log_dict = batch_log
            if not self.no_group_logging:
                self.log_dict[self.group_count_field] = group_counts
            self.log_dict[self.count_field] = count
        else:  # take a running average across batches otherwise
            for k, v in batch_log.items():
                if k.startswith(self.group_prefix):
                    if self.no_group_logging:
                        continue
                    self.log_dict[k] = update_average(
                        self.log_dict[k],
                        self.log_dict[self.group_count_field], v, group_counts)
                else:
                    self.log_dict[k] = update_average(
                        self.log_dict[k], self.log_dict[self.count_field], v,
                        count)
            if not self.no_group_logging:
                self.log_dict[self.group_count_field] += group_counts
            self.log_dict[self.count_field] += count
        self._has_log = True
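
In this example `update_average` is a count-weighted running average rather than an EMA: it merges a previous average (with its sample or per-group counts) with the current batch's average. A minimal sketch consistent with the four-argument calls above; the exact guard against empty groups is an assumption:

import torch

def update_average(prev_avg, prev_counts, curr_avg, curr_counts):
    # Merge two averages weighted by their counts; works elementwise
    # for per-group tensors as well as for scalar statistics.
    denom = prev_counts + curr_counts
    if isinstance(denom, torch.Tensor):
        denom = denom + (denom == 0).float()  # avoid 0/0 for empty groups
    elif denom == 0:
        return 0.0
    return (prev_counts * prev_avg + curr_counts * curr_avg) / denom
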
Example #3
def train_model(model, model_test, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # best_model_wts = model.state_dict()
    # best_acc = 0.0
    warm_up = 0.1  # start from 0.1 * the base learning rate
    warm_iteration = round(dataset_sizes['satellite'] / opt.batchsize) * opt.warm_epoch  # warm up over the first opt.warm_epoch epochs

    if opt.arcface:
        criterion_arcface = losses.ArcFaceLoss(num_classes=opt.nclasses, embedding_size=512)
    if opt.cosface:
        criterion_cosface = losses.CosFaceLoss(num_classes=opt.nclasses, embedding_size=512)
    if opt.circle:
        criterion_circle = CircleLoss(m=0.25, gamma=32)  # gamma = 64 may lead to a better result.
    if opt.triplet:
        miner = miners.MultiSimilarityMiner()
        criterion_triplet = losses.TripletMarginLoss(margin=0.3)
    if opt.lifted:
        criterion_lifted = losses.GeneralizedLiftedStructureLoss(neg_margin=1, pos_margin=0)
    if opt.contrast:
        criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
    if opt.sphere:
        criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4)

    for epoch in range(num_epochs - start_epoch):
        epoch = epoch + start_epoch
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            running_corrects2 = 0.0
            running_corrects3 = 0.0
            # Iterate over data.
            for data, data2, data3, data4 in zip(dataloaders['satellite'], dataloaders['street'], dataloaders['drone'],
                                                 dataloaders['google']):
                # get the inputs
                inputs, labels = data
                inputs2, labels2 = data2
                inputs3, labels3 = data3
                inputs4, labels4 = data4
                now_batch_size, c, h, w = inputs.shape
                if now_batch_size < opt.batchsize:  # skip the last batch
                    continue
                if use_gpu:
                    inputs = Variable(inputs.cuda().detach())
                    inputs2 = Variable(inputs2.cuda().detach())
                    inputs3 = Variable(inputs3.cuda().detach())
                    labels = Variable(labels.cuda().detach())
                    labels2 = Variable(labels2.cuda().detach())
                    labels3 = Variable(labels3.cuda().detach())
                    if opt.extra_Google:
                        inputs4 = Variable(inputs4.cuda().detach())
                        labels4 = Variable(labels4.cuda().detach())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                if phase == 'val':
                    with torch.no_grad():
                        outputs, outputs2 = model(inputs, inputs2)
                else:
                    if opt.views == 2:
                        outputs, outputs2 = model(inputs, inputs2)
                    elif opt.views == 3:
                        if opt.extra_Google:
                            outputs, outputs2, outputs3, outputs4 = model(inputs, inputs2, inputs3, inputs4)
                        else:
                            outputs, outputs2, outputs3 = model(inputs, inputs2, inputs3)

                return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.lifted or opt.sphere

                if opt.views == 2:
                    _, preds = torch.max(outputs.data, 1)
                    _, preds2 = torch.max(outputs2.data, 1)
                    loss = criterion(outputs, labels) + criterion(outputs2, labels2)
                elif opt.views == 3:
                    if return_feature:
                        logits, ff = outputs
                        logits2, ff2 = outputs2
                        logits3, ff3 = outputs3
                        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                        fnorm2 = torch.norm(ff2, p=2, dim=1, keepdim=True)
                        fnorm3 = torch.norm(ff3, p=2, dim=1, keepdim=True)
                        ff = ff.div(fnorm.expand_as(ff))  # 8*512,tensor
                        ff2 = ff2.div(fnorm2.expand_as(ff2))
                        ff3 = ff3.div(fnorm3.expand_as(ff3))
                        loss = criterion(logits, labels) + criterion(logits2, labels2) + criterion(logits3, labels3)
                        _, preds = torch.max(logits.data, 1)
                        _, preds2 = torch.max(logits2.data, 1)
                        _, preds3 = torch.max(logits3.data, 1)
                        # Combine all views into a single loss computation; add '--loss_merge' to run.sh to enable this.
                        if opt.loss_merge:
                            ff_all = torch.cat((ff, ff2, ff3), dim=0)
                            labels_all = torch.cat((labels, labels2, labels3), dim=0)
                        if opt.extra_Google:
                            logits4, ff4 = outputs4
                            fnorm4 = torch.norm(ff4, p=2, dim=1, keepdim=True)
                            ff4 = ff4.div(fnorm4.expand_as(ff4))
                            loss += criterion(logits4, labels4)
                            if opt.loss_merge:
                                ff_all = torch.cat((ff_all, ff4), dim=0)
                                labels_all = torch.cat((labels_all, labels4), dim=0)
                        if opt.arcface:
                            if opt.loss_merge:
                                loss += criterion_arcface(ff_all, labels_all)
                            else:
                                loss += criterion_arcface(ff, labels) + criterion_arcface(ff2, labels2) + criterion_arcface(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_arcface(ff4, labels4)  # /now_batch_size
                        if opt.cosface:
                            if opt.loss_merge:
                                loss += criterion_cosface(ff_all, labels_all)
                            else:
                                loss += criterion_cosface(ff, labels) + criterion_cosface(ff2, labels2) + criterion_cosface(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_cosface(ff4, labels4)  # /now_batch_size
                        if opt.circle:
                            if opt.loss_merge:
                                loss += criterion_circle(*convert_label_to_similarity(ff_all, labels_all)) / now_batch_size
                            else:
                                loss += criterion_circle(*convert_label_to_similarity(ff, labels)) / now_batch_size + criterion_circle(*convert_label_to_similarity(ff2, labels2)) / now_batch_size + criterion_circle(*convert_label_to_similarity(ff3, labels3)) / now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_circle(*convert_label_to_similarity(ff4, labels4)) / now_batch_size
                        if opt.triplet:
                            if opt.loss_merge:
                                hard_pairs_all = miner(ff_all, labels_all)
                                loss += criterion_triplet(ff_all, labels_all, hard_pairs_all)
                            else:
                                hard_pairs = miner(ff, labels)
                                hard_pairs2 = miner(ff2, labels2)
                                hard_pairs3 = miner(ff3, labels3)
                                loss += criterion_triplet(ff, labels, hard_pairs) + criterion_triplet(ff2, labels2, hard_pairs2) + criterion_triplet(ff3, labels3, hard_pairs3)  # /now_batch_size
                                if opt.extra_Google:
                                    hard_pairs4 = miner(ff4, labels4)
                                    loss += criterion_triplet(ff4, labels4, hard_pairs4)
                        if opt.lifted:
                            if opt.loss_merge:
                                loss += criterion_lifted(ff_all, labels_all)
                            else:
                                loss += criterion_lifted(ff, labels) + criterion_lifted(ff2, labels2) + criterion_lifted(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_lifted(ff4, labels4)
                        if opt.contrast:
                            if opt.loss_merge:
                                loss += criterion_contrast(ff_all, labels_all)
                            else:
                                loss += criterion_contrast(ff, labels) + criterion_contrast(ff2, labels2) + criterion_contrast(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_contrast(ff4, labels4)
                        if opt.sphere:
                            if opt.loss_merge:
                                loss += criterion_sphere(ff_all, labels_all) / now_batch_size
                            else:
                                loss += criterion_sphere(ff, labels) / now_batch_size + criterion_sphere(ff2, labels2) / now_batch_size + criterion_sphere(ff3, labels3) / now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_sphere(ff4, labels4)

                    else:
                        _, preds = torch.max(outputs.data, 1)
                        _, preds2 = torch.max(outputs2.data, 1)
                        _, preds3 = torch.max(outputs3.data, 1)
                        if opt.loss_merge:
                            outputs_all = torch.cat((outputs, outputs2, outputs3), dim=0)
                            labels_all = torch.cat((labels, labels2, labels3), dim=0)
                            if opt.extra_Google:
                                outputs_all = torch.cat((outputs_all, outputs4), dim=0)
                                labels_all = torch.cat((labels_all, labels4), dim=0)
                            loss = 4 * criterion(outputs_all, labels_all)  # criterion averages over the merged batch; rescale toward the per-view sum
                        else:
                            loss = criterion(outputs, labels) + criterion(outputs2, labels2) + criterion(outputs3, labels3)
                            if opt.extra_Google:
                                loss += criterion(outputs4, labels4)

                # backward + optimize only if in training phase
                if epoch < opt.warm_epoch and phase == 'train':
                    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                    loss *= warm_up

                if phase == 'train':
                    if fp16:  # use the amp optimizer wrapper to backpropagate the scaled loss
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()
                    ##########
                    if opt.moving_avg < 1.0:
                        update_average(model_test, model, opt.moving_avg)

                # statistics
                if int(version[0]) > 0 or int(version[2]) > 3:  # for the new version like 0.4.0, 0.5.0 and 1.0.0
                    running_loss += loss.item() * now_batch_size
                else:  # for the old version like 0.3.0 and 0.3.1
                    running_loss += loss.data[0] * now_batch_size
                running_corrects += float(torch.sum(preds == labels.data))
                running_corrects2 += float(torch.sum(preds2 == labels2.data))
                if opt.views == 3:
                    running_corrects3 += float(torch.sum(preds3 == labels3.data))

            epoch_loss = running_loss / dataset_sizes['satellite']
            epoch_acc = running_corrects / dataset_sizes['satellite']
            epoch_acc2 = running_corrects2 / dataset_sizes['satellite']

            if opt.views == 2:
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc,
                                                                                         epoch_acc2))
            elif opt.views == 3:
                epoch_acc3 = running_corrects3 / dataset_sizes['satellite']
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f} Drone_Acc: {:.4f}'.format(phase,
                                                                                                           epoch_loss,
                                                                                                           epoch_acc,
                                                                                                           epoch_acc2,
                                                                                                           epoch_acc3))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)
            # deep copy the model
            if phase == 'train':
                scheduler.step()
            last_model_wts = model.state_dict()
            if epoch % 20 == 19:
                save_network(model, opt.name, epoch)
            # draw_curve(epoch)

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # print('Best val Acc: {:4f}'.format(best_acc))
    # save_network(model_test, opt.name+'adapt', epoch)

    return model
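
When `opt.circle` is set, Example #3 also calls `convert_label_to_similarity`, which is absent from this listing. Below is a sketch following the widely used CircleLoss reference implementation, extracting within-class and between-class cosine similarities from a batch of L2-normalized features; treat the details as an assumption about this particular repository:

def convert_label_to_similarity(normed_feature, label):
    # Pairwise cosine similarity of L2-normalized features.
    similarity_matrix = normed_feature @ normed_feature.t()
    label_matrix = label.unsqueeze(1) == label.unsqueeze(0)
    # Keep each unordered pair exactly once (strict upper triangle).
    positive_matrix = label_matrix.triu(diagonal=1)
    negative_matrix = label_matrix.logical_not().triu(diagonal=1)
    similarity_matrix = similarity_matrix.view(-1)
    # Similarities of same-class pairs, then of different-class pairs.
    return (similarity_matrix[positive_matrix.view(-1)],
            similarity_matrix[negative_matrix.view(-1)])
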
Example #4

    def train(self, generator, discriminator, train_data_dir):
        
        # convert device
        generator.to(self.device)
        discriminator.to(self.device)

        # create a shadow copy of the generator
        generator_shadow = copy.deepcopy(generator)

        # initialize the gen_shadow weights equal to the
        # weights of gen
        update_average(generator_shadow, generator, beta=0)
        generator_shadow.train()
        
        optimizer_G = Adam(generator.parameters(),
                            lr=self.learning_rate,
                            betas=self.betas)
        optimizer_D = Adam(discriminator.parameters(),
                            lr=self.learning_rate, 
                            betas=self.betas)

        image_size = 2 ** (self.depth + 1)
        print("Construct dataset")
        train_dataset = ImageDataset(train_data_dir, image_size)
        print("Construct optimizers")

        now = datetime.datetime.now()
        checkpoint_dir = os.path.join(self.checkpoint_root, f"{now.strftime('%Y%m%d_%H%M')}-progan")
        sample_dir = os.path.join(checkpoint_dir, "sample")
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(sample_dir, exist_ok=True)

        # training loop
        loss = losses.WGAN_GP(discriminator)
        print("Training Starts.")
        start_time = time.time()
        
        logger = Logger()
        iterations = 0
        fixed_input = torch.randn(16, self.latent_dim).to(self.device)
        fixed_input /= torch.sqrt(torch.sum(fixed_input*fixed_input + 1e-8, dim=1, keepdim=True))

        for current_depth in range(self.depth):
            print(f"Currently working on Depth: {current_depth}")
            current_res = np.power(2, current_depth + 2)
            print(f"Current resolution: {current_res} x {current_res}")

            train_dataloader = DataLoader(train_dataset, batch_size=self.batch_sizes[current_depth],
                                        shuffle=True, drop_last=True, num_workers=4)
            
            ticker = 1
            
            for epoch in range(self.num_epochs[current_depth]):
                print(f"epoch {epoch}")
                total_batches = len(train_dataloader)
                fader_point = int((self.fade_ins[current_depth] / 100)
                                * self.num_epochs[current_depth] * total_batches)
                

                for i, batch in enumerate(train_dataloader):
                    alpha = ticker / fader_point if ticker <= fader_point else 1

                    images = batch.to(self.device)

                    gan_input = torch.randn(images.shape[0], self.latent_dim).to(self.device)
                    gan_input /= torch.sqrt(torch.sum(gan_input*gan_input + 1e-8, dim=1, keepdim=True))

                    real_samples = progressive_downsampling(images, current_depth, self.depth, alpha)
                    # generate a batch of fake samples
                    fake_samples = generator(gan_input, current_depth, alpha).detach()
                    d_loss = loss.dis_loss(real_samples, fake_samples, current_depth, alpha)

                    # optimize discriminator
                    optimizer_D.zero_grad()
                    d_loss.backward()
                    optimizer_D.step()

                    d_loss_val = d_loss.item()

                    gan_input = torch.randn(images.shape[0], self.latent_dim).to(self.device)
                    gan_input /= torch.sqrt(torch.sum(gan_input*gan_input + 1e-8, dim=1, keepdim=True))
                    fake_samples = generator(gan_input, current_depth, alpha)
                    g_loss = loss.gen_loss(real_samples, fake_samples, current_depth, alpha)
                    
                    optimizer_G.zero_grad()
                    g_loss.backward()
                    optimizer_G.step()

                    update_average(generator_shadow, generator, beta=0.999)
                    
                    g_loss_val = g_loss.item()

                    elapsed = time.time() - start_time

                    logger.log(depth=current_depth,
                               epoch=epoch,
                               i=i,
                               iterations=iterations,
                               loss_G=g_loss_val,
                               loss_D=d_loss_val,
                               time=elapsed)
                    
                    ticker += 1
                    iterations += 1

                logger.output_log(f"{checkpoint_dir}/log.csv")
                
                generator.eval()
                with torch.no_grad():
                    sample_images = generator(fixed_input, current_depth, alpha)
                save_samples(sample_images, current_depth, sample_dir, current_depth, epoch, image_size)
                generator.train()
                if current_depth == self.depth - 1:  # range() tops out at self.depth - 1, so save at the final depth
                    torch.save(generator.state_dict(), f"{checkpoint_dir}/{current_depth}_{epoch}_generator.pth")
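
Example #4 also relies on a `progressive_downsampling` helper not shown in this listing. A minimal ProGAN-style sketch matching the call `progressive_downsampling(images, current_depth, self.depth, alpha)`; the fade-in blending and the exact resolution mapping are assumptions:

import torch.nn.functional as F

def progressive_downsampling(real_batch, depth, max_depth, alpha):
    # Downsample full-resolution reals to the resolution trained at this
    # depth (2 ** (depth + 2) per the loop above, with full resolution
    # image_size = 2 ** (max_depth + 1)).
    down_factor = 2 ** (max_depth - depth - 1)
    real_samples = F.avg_pool2d(real_batch, down_factor) if down_factor > 1 else real_batch
    if depth > 0 and alpha < 1:
        # During fade-in, blend with a 2x-coarser version upsampled back,
        # mirroring how the generator fades in its newest block.
        coarser = F.interpolate(F.avg_pool2d(real_samples, 2), scale_factor=2)
        real_samples = alpha * real_samples + (1 - alpha) * coarser
    return real_samples
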