示例#1
0
def main():
    """Train the selected model, checkpointing on validation-loss improvement.

    Uses the module-level ``parser`` for CLI arguments and the globals
    ``min_loss`` / ``best_acc`` (initialised elsewhere in this file) to track
    the best validation loss and its corresponding accuracy.
    """
    global args, min_loss, best_acc
    args = parser.parse_args()
    device_counts = torch.cuda.device_count()
    print('there is %d gpus in usage' % (device_counts))
    # save source script
    set_prefix(args.prefix, __file__)
    model = model_selector(args.model_type)
    print(model)
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        # this script requires a GPU (criterion below is moved to .cuda())
        raise RuntimeError('there is no gpu')

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # accelerate the speed of training
    cudnn.benchmark = True

    train_loader, val_loader = load_dataset()
    # class_names=['LESION', 'NORMAL']
    class_names = train_loader.dataset.class_names
    print(class_names)
    criterion = nn.BCELoss().cuda()

    # learning rate decay per epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    since = time.time()
    print('-' * 10)
    for epoch in range(args.epochs):
        # NOTE(review): scheduler.step() before train() follows the old
        # (pre-1.1) PyTorch ordering — confirm against the installed version.
        exp_lr_scheduler.step()
        train(train_loader, model, optimizer, criterion, epoch)
        cur_loss, cur_acc = validate(model, val_loader, criterion)
        is_best = cur_loss < min_loss
        # BUG FIX: min_loss was never updated (only a local best_loss was
        # computed), so every epoch compared against the initial value and
        # the checkpoint could record a stale loss. Track the running minimum.
        min_loss = min(cur_loss, min_loss)
        if is_best:
            best_acc = cur_acc
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model_type,
                'state_dict': model.state_dict(),
                'min_loss': min_loss,
                'acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    check_point = torch.load(add_prefix(args.prefix, args.best_model_path))
    print('min_loss=%.4f, best_acc=%.4f' %
          (check_point['min_loss'], check_point['acc']))
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))
示例#2
0
def main():
    """Train a DenseNet-121 binary classifier and checkpoint the best model.

    Uses the module-level ``parser`` for CLI arguments and the global
    ``best_acc`` (initialised elsewhere in this file) to track the best
    validation accuracy seen so far.
    """
    global args, best_acc
    args = parser.parse_args()
    # save source script
    set_prefix(args.prefix, __file__)
    model = models.densenet121(pretrained=False, num_classes=2)
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        # CPU-only run is allowed here but warned about (unlike example 1,
        # which raises); note criterion below is still moved to .cuda().
        warnings.warn('there is no gpu')

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # accelerate the speed of training
    cudnn.benchmark = True

    train_loader, val_loader = load_dataset()
    # class_names=['LESION', 'NORMAL']
    class_names = train_loader.dataset.classes
    print(class_names)
    if args.is_focal_loss:
        print('try focal loss!!')
        criterion = FocalLoss().cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    # learning rate decay per epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    since = time.time()
    print('-' * 10)
    for epoch in range(args.epochs):
        exp_lr_scheduler.step()
        train(train_loader, model, optimizer, criterion, epoch)
        cur_accuracy = validate(model, val_loader, criterion)
        is_best = cur_accuracy > best_acc
        best_acc = max(cur_accuracy, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                # BUG FIX: the checkpoint recorded 'resnet18' although the
                # model built above is densenet121.
                'arch': 'densenet121',
                'state_dict': model.state_dict(),
                'best_accuracy': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # compute validate meter such as confusion matrix
    compute_validate_meter(model, add_prefix(args.prefix,
                                             args.best_model_path), val_loader)
    # save running parameter setting to json
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))
class PoemImageEmbedTrainer():
    """Trains a joint poem/image embedding model on paired (image, poem) data.

    Wires up BERT tokenisation, train/test image transforms, data loaders,
    and an Adam optimizer over only the two linear projection heads
    (poem_embedder.linear and img_embedder.linear) of a DataParallel-wrapped
    PoemImageEmbedModel.
    """

    def __init__(self, train_data, test_data, sentiment_model, batchsize, load_model, device):
        """Build datasets, loaders, model, optimizer and LR scheduler.

        :param train_data: training split passed through to PoemImageEmbedDataset
        :param test_data: test split passed through to PoemImageEmbedDataset
        :param sentiment_model: checkpoint path/state loaded into the image
            embedder's sentiment_feature sub-module via load_dataparallel
        :param batchsize: batch size for both loaders
        :param load_model: optional checkpoint path; falsy to start fresh
        :param device: torch device the model is moved to
        """
        self.device = device
        self.train_data = train_data
        self.test_data = test_data
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Training-time augmentation; test transform is a deterministic resize.
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

        self.test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

        img_dir = 'data/image'
        self.train_set = PoemImageEmbedDataset(self.train_data, img_dir,
                                               tokenizer=self.tokenizer, max_seq_len=100,
                                               transform=self.train_transform)
        self.train_loader = DataLoader(self.train_set, batch_size=batchsize, shuffle=True, num_workers=4)

        self.test_set = PoemImageEmbedDataset(self.test_data, img_dir,
                                              tokenizer=self.tokenizer, max_seq_len=100,
                                              transform=self.test_transform)
        self.test_loader = DataLoader(self.test_set, batch_size=batchsize, num_workers=4)

        self.model = PoemImageEmbedModel(device)

        self.model = DataParallel(self.model)
        # Restore pretrained sentiment weights into the image branch before
        # optionally loading a full checkpoint (which would overwrite them).
        load_dataparallel(self.model.module.img_embedder.sentiment_feature, sentiment_model)
        if load_model:
            logger.info('load model from '+ load_model)
            self.model.load_state_dict(torch.load(load_model))
        self.model.to(device)
        # Only the two projection heads are optimised; the BERT / CNN
        # backbones are left untouched by this optimizer.
        self.optimizer = optim.Adam(list(self.model.module.poem_embedder.linear.parameters()) + \
                                    list(self.model.module.img_embedder.linear.parameters()), lr=1e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[2, 4, 6], gamma=0.33)

    def train_epoch(self, epoch, log_interval, save_interval, ckpt_file):
        """Run one epoch: log every `log_interval` batches with an ETA, and
        save a checkpoint every `save_interval` batches.
        """
        self.model.train()
        # running_ls: loss accumulated since the last log; acc_ls: since epoch start
        running_ls = 0
        acc_ls = 0
        start = time.time()
        num_batches = len(self.train_loader)
        for i, batch in enumerate(self.train_loader):
            img1, ids1, mask1, img2, ids2, mask2 = [t.to(self.device) for t in batch]
            self.model.zero_grad()
            # The model's forward returns the loss directly; under DataParallel
            # this is presumably a per-device vector, hence the ones_like seed
            # for backward and the .mean() when accumulating — TODO confirm.
            loss = self.model(img1, ids1, mask1, img2, ids2, mask2)
            loss.backward(torch.ones_like(loss))
            running_ls += loss.mean().item()
            acc_ls += loss.mean().item()
            self.optimizer.step()

            if (i + 1) % log_interval == 0:
                # ETA from average iteration throughput so far.
                elapsed_time = time.time() - start
                iters_per_sec = (i + 1) / elapsed_time
                remaining = (num_batches - i - 1) / iters_per_sec
                remaining_time = time.strftime("%H:%M:%S", time.gmtime(remaining))

                print('[{:>2}, {:>4}/{}] running loss:{:.4} acc loss:{:.4} {:.3}iters/s {} left'.format(
                    epoch, (i + 1), num_batches, running_ls / log_interval, acc_ls /(i+1),
                    iters_per_sec, remaining_time))
                running_ls = 0

            if (i + 1) % save_interval == 0:
                self.save_model(ckpt_file)

    def save_model(self, file):
        """Persist the (DataParallel-wrapped) model state dict to `file`."""
        torch.save(self.model.state_dict(), file)
                featureRs = np.concatenate((featureRs, featureR), 0)

        result = {
            'fl': featureLs,
            'fr': featureRs,
            'fold': folds,
            'flag': flags
        }
        # save tmp_result
        scipy.io.savemat('./result/tmp_result.mat', result)
        accs = evaluation_10_fold('./result/tmp_result.mat')
        _print('    ave: {:.4f}'.format(np.mean(accs) * 100))

    # save model
    if epoch % SAVE_FREQ == 0:
        msg = 'Saving checkpoint: {}'.format(epoch)
        _print(msg)
        if multi_gpus:
            net_state_dict = net.module.state_dict()
        else:
            net_state_dict = net.state_dict()
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        torch.jit.save(torch.jit.script(net),
                       os.path.join(save_dir, 'mobileface%03d.pt' % epoch))
        torch.save({
            'epoch': epoch,
            'net_state_dict': net_state_dict
        }, os.path.join(save_dir, '%03d.ckpt' % epoch))
print('finishing training')
示例#5
0
        z = Variable(z, volatile=True)
    random_image = G(z)
    fixed_image = G(fixed_z)
    G.train()  # stop test and start train

    p = DIR + '/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'
    fixed_p = DIR + '/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'
    utils.save_result(random_image, (epoch+1), save=True, path=p)
    utils.save_result(fixed_image, (epoch+1), save=True, path=fixed_p)
    train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
    train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
    train_hist['per_epoch_times'].append(per_epoch_time)

end_time = time()
# BUG FIX: was `end_time - end_time`, which is always 0. The total training
# wall-clock time is the sum of the per-epoch durations recorded in train_hist.
total_time = sum(train_hist['per_epoch_times'])
print("Avg per epoch time: %.2f, total %d epochs time: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_times'])), EPOCH, total_time))
print("Training finish!!!...")

# save parameters
torch.save(G.state_dict(), DIR + "/generator_param.pkl")
torch.save(D.state_dict(), DIR + "/discriminator_param.pkl")

# save history
p = DIR + '/history.png'
utils.save_history(train_hist, save=True, path=p)

# save animation built from the fixed-noise samples written each epoch
prefix = DIR + '/Fixed_results/MNIST_GAN_'
p = DIR + '/animation.gif'
utils.save_animation(EPOCH, prefix=prefix, path=p)
示例#6
0
    def train(self):
        """Jointly train the photo and sketch networks.

        Total loss per batch = photo category CE + sketch category CE +
        normalised triplet loss between sketch and photo features.
        Periodically (every self.test_f epochs) runs instance-recall testing
        and optionally checkpoints both networks.
        """
        # Build identical backbones for the photo and sketch branches.
        if self.net == 'vgg16':
            photo_net = DataParallel(self._get_vgg16()).cuda()
            sketch_net = DataParallel(self._get_vgg16()).cuda()
        elif self.net == 'resnet34':
            photo_net = DataParallel(self._get_resnet34()).cuda()
            sketch_net = DataParallel(self._get_resnet34()).cuda()
        elif self.net == 'resnet50':
            photo_net = DataParallel(self._get_resnet50()).cuda()
            sketch_net = DataParallel(self._get_resnet50()).cuda()
        else:
            # BUG FIX: an unknown net name previously left photo_net/sketch_net
            # undefined and failed later with a confusing NameError.
            raise ValueError('unsupported net: {}'.format(self.net))

        if self.fine_tune:
            # Sketch checkpoint path is derived from the photo path by name.
            photo_net_root = self.model_root
            sketch_net_root = self.model_root.replace('photo', 'sketch')

            photo_net.load_state_dict(
                t.load(photo_net_root, map_location=t.device('cpu')))
            sketch_net.load_state_dict(
                t.load(sketch_net_root, map_location=t.device('cpu')))

        print('net')
        print(photo_net)

        # triplet_loss = nn.TripletMarginLoss(margin=self.margin, p=self.p).cuda()
        photo_cat_loss = nn.CrossEntropyLoss().cuda()
        sketch_cat_loss = nn.CrossEntropyLoss().cuda()

        my_triplet_loss = TripletLoss().cuda()

        # optimizer — one per branch, stepped together each batch
        photo_optimizer = t.optim.Adam(photo_net.parameters(), lr=self.lr)
        sketch_optimizer = t.optim.Adam(sketch_net.parameters(), lr=self.lr)

        if self.vis:
            vis = Visualizer(self.env)

        triplet_loss_meter = AverageValueMeter()
        sketch_cat_loss_meter = AverageValueMeter()
        photo_cat_loss_meter = AverageValueMeter()

        data_loader = TripleDataLoader(self.dataloader_opt)
        dataset = data_loader.load_data()

        for epoch in range(self.epochs):

            print('---------------{0}---------------'.format(epoch))

            # Periodic evaluation / checkpointing (also runs at epoch 0).
            if self.test and epoch % self.test_f == 0:

                tester_config = Config()
                tester_config.test_bs = 128
                tester_config.photo_net = photo_net
                tester_config.sketch_net = sketch_net

                tester_config.photo_test = self.photo_test
                tester_config.sketch_test = self.sketch_test

                tester = Tester(tester_config)
                test_result = tester.test_instance_recall()

                # BUG FIX: `vis` only exists when self.vis is enabled; the
                # unconditional plot call raised NameError when testing was
                # on but visualisation was off.
                if self.vis:
                    result_key = list(test_result.keys())
                    vis.plot('recall',
                             np.array([
                                 test_result[result_key[0]],
                                 test_result[result_key[1]]
                             ]),
                             legend=[result_key[0], result_key[1]])
                if self.save_model:
                    t.save(
                        photo_net.state_dict(), self.save_dir + '/photo' +
                        '/photo_' + self.net + '_%s.pth' % epoch)
                    t.save(
                        sketch_net.state_dict(), self.save_dir + '/sketch' +
                        '/sketch_' + self.net + '_%s.pth' % epoch)

            photo_net.train()
            sketch_net.train()

            for ii, data in enumerate(dataset):

                photo_optimizer.zero_grad()
                sketch_optimizer.zero_grad()

                photo = data['P'].cuda()
                sketch = data['S'].cuda()
                label = data['L'].cuda()

                # Each branch returns (category logits, feature embedding).
                p_cat, p_feature = photo_net(photo)
                s_cat, s_feature = sketch_net(sketch)

                # category loss
                p_cat_loss = photo_cat_loss(p_cat, label)
                s_cat_loss = sketch_cat_loss(s_cat, label)

                photo_cat_loss_meter.add(p_cat_loss.item())
                sketch_cat_loss_meter.add(s_cat_loss.item())

                # triplet loss
                loss = p_cat_loss + s_cat_loss

                # tri_record = 0.
                '''
                for i in range(self.batch_size):
                    # negative
                    negative_feature = t.cat([p_feature[0:i, :], p_feature[i + 1:, :]], dim=0)
                    # print('negative_feature.size :', negative_feature.size())
                    # photo_feature
                    anchor_feature = s_feature[i, :]
                    anchor_feature = anchor_feature.expand_as(negative_feature)
                    # print('anchor_feature.size :', anchor_feature.size())

                    # positive
                    positive_feature = p_feature[i, :]
                    positive_feature = positive_feature.expand_as(negative_feature)
                    # print('positive_feature.size :', positive_feature.size())

                    tri_loss = triplet_loss(anchor_feature, positive_feature, negative_feature)

                    tri_record = tri_record + tri_loss

                    # print('tri_loss :', tri_loss)
                    loss = loss + tri_loss
                '''
                # print('tri_record : ', tri_record)

                # Normalise by the number of negatives per anchor (batch - 1).
                my_tri_loss = my_triplet_loss(
                    s_feature, p_feature) / (self.batch_size - 1)
                triplet_loss_meter.add(my_tri_loss.item())
                # print('my_tri_loss : ', my_tri_loss)

                # print(tri_record - my_tri_loss)
                loss = loss + my_tri_loss
                # print('loss :', loss)
                # loss = loss / opt.batch_size

                loss.backward()

                photo_optimizer.step()
                sketch_optimizer.step()

                if self.vis:
                    vis.plot('triplet_loss',
                             np.array([
                                 triplet_loss_meter.value()[0],
                                 photo_cat_loss_meter.value()[0],
                                 sketch_cat_loss_meter.value()[0]
                             ]),
                             legend=[
                                 'triplet_loss', 'photo_cat_loss',
                                 'sketch_cat_loss'
                             ])

                # Meters are reset per batch, so the plot shows per-batch values.
                triplet_loss_meter.reset()
                photo_cat_loss_meter.reset()
                sketch_cat_loss_meter.reset()
示例#7
0
def train_model(train_dataset, train_num_each, val_dataset, val_num_each):
    num_train = len(train_dataset)
    num_val = len(val_dataset)

    train_useful_start_idx = get_useful_start_idx(sequence_length,
                                                  train_num_each)

    val_useful_start_idx = get_useful_start_idx(sequence_length, val_num_each)

    num_train_we_use = len(train_useful_start_idx) // num_gpu * num_gpu
    num_val_we_use = len(val_useful_start_idx) // num_gpu * num_gpu
    # num_train_we_use = 4
    # num_val_we_use = 800

    train_we_use_start_idx = train_useful_start_idx[0:num_train_we_use]
    val_we_use_start_idx = val_useful_start_idx[0:num_val_we_use]

    train_idx = []
    for i in range(num_train_we_use):
        for j in range(sequence_length):
            train_idx.append(train_we_use_start_idx[i] + j)

    val_idx = []
    for i in range(num_val_we_use):
        for j in range(sequence_length):
            val_idx.append(val_we_use_start_idx[i] + j)

    num_train_all = len(train_idx)
    num_val_all = len(val_idx)

    print('num train start idx : {:6d}'.format(len(train_useful_start_idx)))
    print('last idx train start: {:6d}'.format(train_useful_start_idx[-1]))
    print('num of train dataset: {:6d}'.format(num_train))
    print('num of train we use : {:6d}'.format(num_train_we_use))
    print('num of all train use: {:6d}'.format(num_train_all))
    print('num valid start idx : {:6d}'.format(len(val_useful_start_idx)))
    print('last idx valid start: {:6d}'.format(val_useful_start_idx[-1]))
    print('num of valid dataset: {:6d}'.format(num_val))
    print('num of valid we use : {:6d}'.format(num_val_we_use))
    print('num of all valid use: {:6d}'.format(num_val_all))

    train_loader = DataLoader(train_dataset,
                              batch_size=train_batch_size,
                              sampler=train_idx,
                              num_workers=workers,
                              pin_memory=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=val_batch_size,
                            sampler=val_idx,
                            num_workers=workers,
                            pin_memory=False)
    model = multi_lstm_4loss()
    sig_f = nn.Sigmoid()

    if use_gpu:
        model = model.cuda()
        sig_f = sig_f.cuda()
    model = DataParallel(model)
    criterion_1 = nn.BCEWithLogitsLoss(size_average=False)
    criterion_2 = nn.CrossEntropyLoss(size_average=False)

    if multi_optim == 0:
        if optimizer_choice == 0:
            optimizer = optim.SGD(model.parameters(),
                                  lr=learning_rate,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_adjust_lr,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    elif multi_optim == 1:
        if optimizer_choice == 0:
            optimizer = optim.SGD([
                {
                    'params': model.module.share.parameters()
                },
                {
                    'params': model.module.lstm.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool1.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool2.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_phase1.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_phase2.parameters(),
                    'lr': learning_rate
                },
            ],
                                  lr=learning_rate / 10,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_adjust_lr,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam([
                {
                    'params': model.module.share.parameters()
                },
                {
                    'params': model.module.lstm.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool1.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool2.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_phase1.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_phase2.parameters(),
                    'lr': learning_rate
                },
            ],
                                   lr=learning_rate / 10)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_accuracy_1 = 0.0
    best_val_accuracy_2 = 0.0  # judge by accu2
    correspond_train_acc_1 = 0.0
    correspond_train_acc_2 = 0.0

    record_np = np.zeros([epochs, 8])

    for epoch in range(epochs):
        # np.random.seed(epoch)
        np.random.shuffle(train_we_use_start_idx)
        train_idx = []
        for i in range(num_train_we_use):
            for j in range(sequence_length):
                train_idx.append(train_we_use_start_idx[i] + j)

        train_loader = DataLoader(train_dataset,
                                  batch_size=train_batch_size,
                                  sampler=train_idx,
                                  num_workers=workers,
                                  pin_memory=False)

        model.train()
        train_loss_11 = 0.0
        train_loss_12 = 0.0
        train_loss_21 = 0.0
        train_loss_22 = 0.0
        train_corrects_11 = 0
        train_corrects_12 = 0
        train_corrects_21 = 0
        train_corrects_22 = 0

        train_start_time = time.time()
        for data in train_loader:
            inputs, labels_1, labels_2 = data
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels_1 = Variable(labels_1.cuda())
                labels_2 = Variable(labels_2.cuda())
            else:
                inputs = Variable(inputs)
                labels_1 = Variable(labels_1)
                labels_2 = Variable(labels_2)

            optimizer.zero_grad()

            outputs_11, outputs_12, outputs_21, outputs_22 = model.forward(
                inputs)

            _, preds_12 = torch.max(outputs_12.data, 1)
            _, preds_22 = torch.max(outputs_22.data, 1)

            sig_out_11 = sig_f(outputs_11.data)
            sig_out_21 = sig_f(outputs_21.data)

            preds_11 = torch.ByteTensor(sig_out_11.cpu() > 0.5)
            preds_11 = preds_11.long()
            train_corrects_11 += torch.sum(preds_11 == labels_1.data.cpu())
            preds_21 = torch.ByteTensor(sig_out_21.cpu() > 0.5)
            preds_21 = preds_21.long()
            train_corrects_21 += torch.sum(preds_21 == labels_1.data.cpu())

            labels_1 = Variable(labels_1.data.float())
            loss_11 = criterion_1(outputs_11, labels_1)
            loss_21 = criterion_1(outputs_21, labels_1)

            loss_12 = criterion_2(outputs_12, labels_2)
            loss_22 = criterion_2(outputs_22, labels_2)
            loss = loss_11 + loss_12 + loss_21 + loss_22
            loss.backward()
            optimizer.step()

            train_loss_11 += loss_11.data[0]
            train_loss_12 += loss_12.data[0]
            train_loss_21 += loss_21.data[0]
            train_loss_22 += loss_22.data[0]
            train_corrects_12 += torch.sum(preds_12 == labels_2.data)
            train_corrects_22 += torch.sum(preds_22 == labels_2.data)

        train_elapsed_time = time.time() - train_start_time
        train_accuracy_11 = train_corrects_11 / num_train_all / 7
        train_accuracy_21 = train_corrects_21 / num_train_all / 7
        train_accuracy_12 = train_corrects_12 / num_train_all
        train_accuracy_22 = train_corrects_22 / num_train_all
        train_average_loss_11 = train_loss_11 / num_train_all / 7
        train_average_loss_21 = train_loss_21 / num_train_all / 7
        train_average_loss_12 = train_loss_12 / num_train_all
        train_average_loss_22 = train_loss_22 / num_train_all

        # begin eval

        model.eval()
        val_loss_11 = 0.0
        val_loss_12 = 0.0
        val_loss_21 = 0.0
        val_loss_22 = 0.0
        val_corrects_11 = 0
        val_corrects_12 = 0
        val_corrects_21 = 0
        val_corrects_22 = 0

        val_start_time = time.time()
        for data in val_loader:
            inputs, labels_1, labels_2 = data
            labels_2 = labels_2[(sequence_length - 1)::sequence_length]
            if use_gpu:
                inputs = Variable(inputs.cuda(), volatile=True)
                labels_1 = Variable(labels_1.cuda(), volatile=True)
                labels_2 = Variable(labels_2.cuda(), volatile=True)
            else:
                inputs = Variable(inputs, volatile=True)
                labels_1 = Variable(labels_1, volatile=True)
                labels_2 = Variable(labels_2, volatile=True)

            # if crop_type == 0 or crop_type == 1:
            #     outputs_1, outputs_2 = model.forward(inputs)
            # elif crop_type == 5:
            #     inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
            #     inputs = inputs.view(-1, 3, 224, 224)
            #     outputs_1, outputs_2 = model.forward(inputs)
            #     outputs_1 = outputs_1.view(5, -1, 7)
            #     outputs_1 = torch.mean(outputs_1, 0)
            #     outputs_2 = outputs_2.view(5, -1, 7)
            #     outputs_2 = torch.mean(outputs_2, 0)
            # elif crop_type == 10:
            #     inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
            #     inputs = inputs.view(-1, 3, 224, 224)
            #     outputs_1, outputs_2 = model.forward(inputs)
            #     outputs_1 = outputs_1.view(10, -1, 7)
            #     outputs_1 = torch.mean(outputs_1, 0)
            #     outputs_2 = outputs_2.view(10, -1, 7)
            #     outputs_2 = torch.mean(outputs_2, 0)
            outputs_11, outputs_12, outputs_21, outputs_22 = model.forward(
                inputs)
            outputs_12 = outputs_12[sequence_length - 1::sequence_length]
            outputs_22 = outputs_22[sequence_length - 1::sequence_length]

            _, preds_12 = torch.max(outputs_12.data, 1)
            _, preds_22 = torch.max(outputs_22.data, 1)

            sig_out_11 = sig_f(outputs_11.data)
            sig_out_21 = sig_f(outputs_21.data)

            preds_11 = torch.ByteTensor(sig_out_11.cpu() > 0.5)
            preds_11 = preds_11.long()
            train_corrects_11 += torch.sum(preds_11 == labels_1.data.cpu())
            preds_21 = torch.ByteTensor(sig_out_21.cpu() > 0.5)
            preds_21 = preds_21.long()
            train_corrects_21 += torch.sum(preds_21 == labels_1.data.cpu())

            labels_1 = Variable(labels_1.data.float())
            loss_11 = criterion_1(outputs_11, labels_1)
            loss_21 = criterion_1(outputs_21, labels_1)

            loss_12 = criterion_2(outputs_12, labels_2)
            loss_22 = criterion_2(outputs_22, labels_2)

            val_loss_11 += loss_11.data[0]
            val_loss_12 += loss_12.data[0]
            val_loss_21 += loss_21.data[0]
            val_loss_22 += loss_22.data[0]
            val_corrects_12 += torch.sum(preds_12 == labels_2.data)
            val_corrects_22 += torch.sum(preds_22 == labels_2.data)

        val_elapsed_time = time.time() - val_start_time
        val_accuracy_11 = val_corrects_11 / num_val_all / 7
        val_accuracy_21 = val_corrects_21 / num_val_all / 7
        val_accuracy_12 = val_corrects_12 / num_val_we_use
        val_accuracy_22 = val_corrects_22 / num_val_we_use
        val_average_loss_11 = val_loss_11 / num_val_all / 7
        val_average_loss_21 = val_loss_21 / num_val_all / 7
        val_average_loss_12 = val_loss_12 / num_val_we_use
        val_average_loss_22 = val_loss_22 / num_val_we_use

        print('epoch: {:4d}'
              ' train time: {:2.0f}m{:2.0f}s'
              ' train accu_11: {:.4f}'
              ' train accu_21: {:.4f}'
              ' valid time: {:2.0f}m{:2.0f}s'
              ' valid accu_11: {:.4f}'
              ' valid accu_21: {:.4f}'.format(
                  epoch, train_elapsed_time // 60, train_elapsed_time % 60,
                  train_accuracy_11, train_accuracy_21, val_elapsed_time // 60,
                  val_elapsed_time % 60, val_accuracy_11, val_accuracy_21))
        print('epoch: {:4d}'
              ' train time: {:2.0f}m{:2.0f}s'
              ' train accu_12: {:.4f}'
              ' train accu_22: {:.4f}'
              ' valid time: {:2.0f}m{:2.0f}s'
              ' valid accu_12: {:.4f}'
              ' valid accu_22: {:.4f}'.format(
                  epoch, train_elapsed_time // 60, train_elapsed_time % 60,
                  train_accuracy_12, train_accuracy_22, val_elapsed_time // 60,
                  val_elapsed_time % 60, val_accuracy_12, val_accuracy_22))

        if optimizer_choice == 0:
            if sgd_adjust_lr == 0:
                exp_lr_scheduler.step()
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler.step(val_average_loss_11 +
                                      val_average_loss_12 +
                                      val_average_loss_21 +
                                      val_average_loss_22)
示例#8
0
class ConditionalProGAN:
    """ Wrapper around the Generator and the Conditional Discriminator """

    def __init__(self, num_classes, depth=7, latent_size=512,
                 learning_rate=0.001, beta_1=0, beta_2=0.99,
                 eps=1e-8, drift=0.001, n_critic=1, use_eql=True,
                 loss="wgan-gp", use_ema=True, ema_decay=0.999,
                 device=th.device("cpu")):
        """
        constructor for the class
        :param num_classes: number of classes required for the conditional gan
        :param depth: depth of the GAN (will be used for each generator and discriminator)
        :param latent_size: latent size of the manifold used by the GAN
        :param learning_rate: learning rate for Adam
        :param beta_1: beta_1 for Adam
        :param beta_2: beta_2 for Adam
        :param eps: epsilon for Adam
        :param n_critic: number of times to update discriminator
                         (Used only if loss is wgan or wgan-gp)
        :param drift: drift penalty for the
                      (Used only if loss is wgan or wgan-gp)
        :param use_eql: whether to use equalized learning rate
        :param loss: the loss function to be used
                     Can either be a string =>
                          ["wgan-gp", "wgan", "lsgan", "lsgan-with-sigmoid",
                          "hinge", "standard-gan" or "relativistic-hinge"]
                     Or an instance of ConditionalGANLoss
        :param use_ema: boolean for whether to use exponential moving averages
        :param ema_decay: value of mu for ema
        :param device: device to run the GAN on (GPU / CPU)
        """

        from torch.optim import Adam
        from torch.nn import DataParallel

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
        self.dis = ConditionalDiscriminator(
            num_classes, height=depth,
            feature_size=latent_size,
            use_eql=use_eql).to(device)

        # if code is to be run on GPU, we can use DataParallel:
        if device == th.device("cuda"):
            self.gen = DataParallel(self.gen)
            self.dis = DataParallel(self.dis)

        # state of the object
        self.latent_size = latent_size
        self.depth = depth
        self.use_ema = use_ema
        self.num_classes = num_classes  # required for matching aware
        self.ema_decay = ema_decay
        self.n_critic = n_critic
        self.use_eql = use_eql
        self.device = device
        self.drift = drift

        # define the optimizers for the discriminator and generator
        self.gen_optim = Adam(self.gen.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        self.dis_optim = Adam(self.dis.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        # define the loss function used for training the GAN
        self.loss = self.__setup_loss(loss)

        # setup the ema for the generator
        if self.use_ema:
            from pro_gan_pytorch.CustomLayers import update_average

            # create a shadow copy of the generator
            self.gen_shadow = copy.deepcopy(self.gen)

            # updater function:
            self.ema_updater = update_average

            # initialize the gen_shadow weights equal to the
            # weights of gen (beta=0 => full copy of the current weights)
            self.ema_updater(self.gen_shadow, self.gen, beta=0)

    def __setup_loss(self, loss):
        """
        resolve the ``loss`` constructor argument into a ConditionalGANLoss object
        :param loss: either a known loss-name string or a ConditionalGANLoss instance
        :return: loss => instance of ConditionalGANLoss
        :raises ValueError: for an unknown string or a non-GANLoss object
        """
        import pro_gan_pytorch.Losses as losses

        if isinstance(loss, str):
            loss = loss.lower()  # lowercase the string
            if loss == "wgan":
                loss = losses.CondWGAN_GP(self.dis, self.drift, use_gp=False)
                # note if you use just wgan, you will have to use weight clipping
                # in order to prevent gradient exploding

            elif loss == "wgan-gp":
                loss = losses.CondWGAN_GP(self.dis, self.drift, use_gp=True)

            elif loss == "lsgan":
                loss = losses.CondLSGAN(self.dis)

            elif loss == "lsgan-with-sigmoid":
                loss = losses.CondLSGAN_SIGMOID(self.dis)

            elif loss == "hinge":
                loss = losses.CondHingeGAN(self.dis)

            elif loss == "standard-gan":
                loss = losses.CondStandardGAN(self.dis)

            elif loss == "relativistic-hinge":
                loss = losses.CondRelativisticAverageHingeGAN(self.dis)

            else:
                raise ValueError("Unknown loss function requested")

        elif not isinstance(loss, losses.ConditionalGANLoss):
            raise ValueError("loss is neither an instance of GANLoss nor a string")

        return loss

    def __progressive_downsampling(self, real_batch, depth, alpha):
        """
        private helper for downsampling the original images in order to facilitate the
        progressive growing of the layers.
        :param real_batch: batch of real samples
        :param depth: depth at which training is going on
        :param alpha: current value of the fader alpha
        :return: real_samples => modified real batch of samples
        """

        from torch.nn import AvgPool2d
        from torch.nn.functional import interpolate

        # downsample the real_batch for the given depth
        down_sample_factor = int(np.power(2, self.depth - depth - 1))
        # NOTE(review): the `max(..., 0)` clamp looks like it was meant to be
        # `max(..., 1)` (a pooling factor of 0 is invalid); it is harmless for
        # the depths reached during training, where the factor is >= 2.
        prior_downsample_factor = max(int(np.power(2, self.depth - depth)), 0)

        ds_real_samples = AvgPool2d(down_sample_factor)(real_batch)

        if depth > 0:
            prior_ds_real_samples = interpolate(AvgPool2d(prior_downsample_factor)(real_batch),
                                                scale_factor=2)
        else:
            prior_ds_real_samples = ds_real_samples

        # real samples are a combination of ds_real_samples and prior_ds_real_samples
        real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)

        # return the so computed real_samples
        return real_samples

    def optimize_discriminator(self, noise, real_batch, labels, depth, alpha):
        """
        performs one step of weight update on discriminator using the batch of data
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
        :param labels: (conditional classes) should be a list of integers
        :param depth: current depth of optimization
        :param alpha: current alpha for fade-in
        :return: current loss value (averaged over the n_critic updates)
        """

        real_samples = self.__progressive_downsampling(real_batch, depth, alpha)

        loss_val = 0
        for _ in range(self.n_critic):
            # generate a batch of samples (detached: no generator gradients here)
            fake_samples = self.gen(noise, depth, alpha).detach()

            loss = self.loss.dis_loss(real_samples, fake_samples,
                                      labels, depth, alpha)

            # optimize discriminator
            self.dis_optim.zero_grad()
            loss.backward()
            self.dis_optim.step()

            loss_val += loss.item()

        return loss_val / self.n_critic

    def optimize_generator(self, noise, real_batch, labels, depth, alpha):
        """
        performs one step of weight update on generator for the given batch_size
        :param noise: input random noise required for generating samples
        :param real_batch: real batch of samples (real samples)
        :param labels: labels for conditional discrimination
        :param depth: depth of the network at which optimization is done
        :param alpha: value of alpha for fade-in effect
        :return: current loss (Wasserstein estimate)
        """

        # create batch of real samples
        real_samples = self.__progressive_downsampling(real_batch, depth, alpha)

        # generate fake samples:
        fake_samples = self.gen(noise, depth, alpha)

        # TODO_complete:
        # Change this implementation for making it compatible for relativisticGAN
        loss = self.loss.gen_loss(real_samples, fake_samples, labels, depth, alpha)

        # optimize the generator
        self.gen_optim.zero_grad()
        loss.backward()
        self.gen_optim.step()

        # if use_ema is true, apply ema to the generator parameters
        if self.use_ema:
            self.ema_updater(self.gen_shadow, self.gen, self.ema_decay)

        # return the loss value
        return loss.item()

    @staticmethod
    def create_grid(samples, scale_factor, img_file):
        """
        utility function to create a grid of GAN samples
        :param samples: generated samples for storing
        :param scale_factor: factor for upscaling the image
        :param img_file: name of file to write
        :return: None (saves a file)
        """
        from torchvision.utils import save_image
        from torch.nn.functional import interpolate

        # upsample the image
        if scale_factor > 1:
            samples = interpolate(samples, scale_factor=scale_factor)

        # save the images:
        save_image(samples, img_file, nrow=int(np.sqrt(len(samples))),
                   normalize=True, scale_each=True)

    @staticmethod
    def __save_label_info_file(label_file, labels):
        """
        utility method for saving a file with labels
        :param label_file: path to the file to be written
        :param labels: label tensor
        :return: None (writes file to disk)
        """
        # write file with the labels written one per line
        with open(label_file, "w") as fp:
            for label in labels:
                fp.write(str(label.item()) + "\n")

    def one_hot_encode(self, labels):
        """
        utility method to one-hot encode the labels
        :param labels: tensor of labels (Batch)
        :return: enc_label: encoded one_hot label
        """
        # lazily build an identity-weight embedding that maps label -> one-hot row
        if not hasattr(self, "label_oh_encoder"):
            self.label_oh_encoder = th.nn.Embedding(self.num_classes, self.num_classes)
            self.label_oh_encoder.weight.data = th.eye(self.num_classes)

        return self.label_oh_encoder(labels.view(-1))

    def train(self, dataset, epochs, batch_sizes,
              fade_in_percentage, start_depth=0, num_workers=3, feedback_factor=100,
              log_dir="./models/", sample_dir="./samples/", save_dir="./models/",
              checkpoint_factor=1):
        """
        Utility method for training the ProGAN. Note that you don't have to necessarily use this
        you can use the optimize_generator and optimize_discriminator for your own training routine.
        :param dataset: object of the dataset used for training.
                        Note that this is not the dataloader (we create dataloader in this method
                        since the batch_sizes for resolutions can be different).
                        Get_item should return (Image, label) in that order
        :param epochs: list of number of epochs to train the network for every resolution
        :param batch_sizes: list of batch_sizes for every resolution
        :param fade_in_percentage: list of percentages of epochs per resolution
                                   used for fading in the new layer
                                   not used for first resolution, but dummy value still needed.
        :param start_depth: start training from this depth. def=0
        :param num_workers: number of workers for reading the data. def=3
        :param feedback_factor: number of logs per epoch. def=100
        :param log_dir: directory for saving the loss logs. def="./models/"
        :param sample_dir: directory for saving the generated samples. def="./samples/"
        :param checkpoint_factor: save model after these many epochs.
                                  Note that only one model is stored per resolution.
                                  during one resolution, the checkpoint will be updated (Rewritten)
                                  according to this factor.
        :param save_dir: directory for saving the models (.pth files)
        :return: None (Writes multiple files to disk)
        """
        from pro_gan_pytorch.DataTools import get_data_loader

        assert self.depth == len(batch_sizes), "batch_sizes not compatible with depth"

        # turn the generator and discriminator into train mode
        self.gen.train()
        self.dis.train()
        if self.use_ema:
            self.gen_shadow.train()

        # create a global time counter
        global_time = time.time()

        # create fixed_input for debugging
        # FIX: honour the `num_workers` argument instead of the hard-coded 3
        temp_data_loader = get_data_loader(dataset, batch_sizes[0],
                                           num_workers=num_workers)
        _, fx_labels = next(iter(temp_data_loader))
        # reshape them properly
        fixed_labels = self.one_hot_encode(fx_labels.view(-1, 1)).to(self.device)
        fixed_input = th.randn(fixed_labels.shape[0],
                               self.latent_size - self.num_classes).to(self.device)
        fixed_input = th.cat((fixed_labels, fixed_input), dim=-1)
        del temp_data_loader  # delete the temp data_loader since it is not required anymore

        os.makedirs(sample_dir, exist_ok=True)  # make sure the directory exists
        self.__save_label_info_file(os.path.join(sample_dir, "labels.txt"), fx_labels)

        print("Starting the training process ... ")
        for current_depth in range(start_depth, self.depth):

            print("\n\nCurrently working on Depth: ", current_depth)
            current_res = np.power(2, current_depth + 2)
            print("Current resolution: %d x %d" % (current_res, current_res))

            data = get_data_loader(dataset, batch_sizes[current_depth], num_workers)
            ticker = 1

            for epoch in range(1, epochs[current_depth] + 1):
                start = timeit.default_timer()  # record time at the start of epoch

                print("\nEpoch: %d" % epoch)
                # FIX: len(data) is the portable way to count batches;
                # len(iter(data)) depends on the loader-iterator exposing __len__
                total_batches = len(data)

                fader_point = int((fade_in_percentage[current_depth] / 100)
                                  * epochs[current_depth] * total_batches)

                # FIX: guard against ZeroDivisionError when
                # feedback_factor > total_batches (log every batch in that case)
                feedback_step = max(int(total_batches / feedback_factor), 1)

                step = 0  # counter for number of iterations

                for (i, batch) in enumerate(data, 1):
                    # calculate the alpha for fading in the layers
                    alpha = ticker / fader_point if ticker <= fader_point else 1

                    # extract current batch of data for training
                    images, labels = batch
                    images = images.to(self.device)
                    labels = labels.view(-1, 1)

                    # create the input to the Generator
                    label_information = self.one_hot_encode(labels).to(self.device)
                    latent_vector = th.randn(images.shape[0],
                                             self.latent_size - self.num_classes).to(self.device)
                    gan_input = th.cat((label_information, latent_vector), dim=-1)

                    # optimize the discriminator:
                    dis_loss = self.optimize_discriminator(gan_input, images,
                                                           labels, current_depth, alpha)

                    # optimize the generator:
                    gen_loss = self.optimize_generator(gan_input, images,
                                                       labels, current_depth, alpha)

                    # provide a loss feedback
                    if i % feedback_step == 0 or i == 1:
                        elapsed = time.time() - global_time
                        elapsed = str(datetime.timedelta(seconds=elapsed))
                        print("Elapsed: [%s]  batch: %d  d_loss: %f  g_loss: %f"
                              % (elapsed, i, dis_loss, gen_loss))

                        # also write the losses to the log file:
                        os.makedirs(log_dir, exist_ok=True)
                        log_file = os.path.join(log_dir, "loss_" + str(current_depth) + ".log")
                        with open(log_file, "a") as log:
                            log.write(str(step) + "\t" + str(dis_loss) +
                                      "\t" + str(gen_loss) + "\n")

                        # create a grid of samples and save it
                        os.makedirs(sample_dir, exist_ok=True)
                        gen_img_file = os.path.join(sample_dir, "gen_" + str(current_depth) +
                                                    "_" + str(epoch) + "_" +
                                                    str(i) + ".png")

                        # this is done to allow for more GPU space
                        self.gen_optim.zero_grad()
                        self.dis_optim.zero_grad()
                        with th.no_grad():
                            self.create_grid(
                                samples=self.gen(
                                    fixed_input,
                                    current_depth,
                                    alpha
                                ) if not self.use_ema
                                else self.gen_shadow(
                                    fixed_input,
                                    current_depth,
                                    alpha
                                ),
                                scale_factor=int(np.power(2, self.depth - current_depth - 1)),
                                img_file=gen_img_file,
                            )

                    # increment the alpha ticker and the step
                    ticker += 1
                    step += 1

                stop = timeit.default_timer()
                print("Time taken for epoch: %.3f secs" % (stop - start))

                if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == epochs[current_depth]:
                    os.makedirs(save_dir, exist_ok=True)
                    gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(current_depth) + ".pth")
                    dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(current_depth) + ".pth")
                    gen_optim_save_file = os.path.join(save_dir,
                                                       "GAN_GEN_OPTIM_" + str(current_depth)
                                                       + ".pth")
                    dis_optim_save_file = os.path.join(save_dir,
                                                       "GAN_DIS_OPTIM_" + str(current_depth)
                                                       + ".pth")

                    th.save(self.gen.state_dict(), gen_save_file)
                    th.save(self.dis.state_dict(), dis_save_file)
                    th.save(self.gen_optim.state_dict(), gen_optim_save_file)
                    th.save(self.dis_optim.state_dict(), dis_optim_save_file)

                    # also save the shadow generator if use_ema is True
                    if self.use_ema:
                        gen_shadow_save_file = os.path.join(save_dir, "GAN_GEN_SHADOW_" +
                                                            str(current_depth) + ".pth")
                        th.save(self.gen_shadow.state_dict(), gen_shadow_save_file)

        # put the gen, shadow_gen and dis in eval mode
        self.gen.eval()
        self.dis.eval()
        if self.use_ema:
            self.gen_shadow.eval()

        print("Training completed ...")
示例#9
0
def main(parser, logger):
    """Train a ResNet-18 + ArcFace head on Omniglot, validate per epoch, then test.

    :param parser: parsed argument namespace (dataset_root, learning_rate,
                   epochs, lr_scheduler_gamma, lr_scheduler_step,
                   num_support_val, classes_per_it_val, num_query_val)
    :param logger: currently unused -- kept for interface compatibility
    :return: None (prints training/validation/testing metrics)
    """
    print('--> Preparing Dataset:')
    trainset = OmniglotDataset(mode='train', root=parser.dataset_root)
    trainloader = data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0)
    valset = dataloader(parser, 'val')
    testset = dataloader(parser, 'test')
    print('--> Building Model:')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = Network.resnet18().to(device)
    model = DataParallel(model)
    # ArcFace margin head: 256-d features, one class per unique training label
    metric = ArcMarginProduct(256, len(np.unique(trainset.y)), s=30, m=0.5).to(device)
    metric = DataParallel(metric)
    criterion = torch.nn.CrossEntropyLoss()
    print('--> Initializing Optimizer and Scheduler:')
    # BUG FIX: the original used `metric.weight`, but after the DataParallel
    # wrap above that attribute no longer exists on `metric` (DataParallel does
    # not forward submodule attributes) and raises AttributeError. Optimizing
    # the wrapped module's parameters is equivalent and works.
    optimizer = torch.optim.Adam(
        [{'params': model.parameters(), 'weight_decay': 5e-4},
         {'params': metric.module.parameters(), 'weight_decay': 5e-4}],
        lr=parser.learning_rate, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                gamma=parser.lr_scheduler_gamma,
                                                step_size=parser.lr_scheduler_step)
    best_acc = 0
    best_state = model.state_dict()
    for epoch in range(parser.epochs):
        print('\nEpoch: %d' % epoch)
        # Training
        train_loss = 0
        train_correct = 0
        train_total = 0
        model.train()
        for batch_index, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device).long()
            feature = model(inputs)
            output = metric(feature, targets)  # margin-adjusted logits
            loss = criterion(output, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
        scheduler.step()
        train_acc = 100. * train_correct / train_total
        print('Training Loss: {} | Accuracy: {}'.format(train_loss / train_total, train_acc))
        # Validating (few-shot episodes)
        val_correct = 0
        val_total = 0
        model.eval()
        for batch_index, (inputs, targets) in enumerate(valset):
            inputs = inputs.to(device)
            targets = targets.to(device)
            feature = model(inputs)
            # NOTE(review): `eval` here shadows the Python builtin -- presumably a
            # project episodic-evaluation helper returning the correct-count; verify.
            correct = eval(input=feature, target=targets, n_support=parser.num_support_val)
            val_correct += correct
            val_total += parser.classes_per_it_val * parser.num_query_val
        val_acc = 100. * val_correct / val_total
        print('Validating Accuracy: {}'.format(val_acc))
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = model.state_dict()
    # Testing: restore the best validation weights and average over 10 passes
    test_correct = 0
    test_total = 0
    model.load_state_dict(best_state)
    for epoch in range(10):
        for batch_index, (inputs, targets) in enumerate(testset):
            inputs = inputs.to(device)
            targets = targets.to(device)
            feature = model(inputs)
            correct = eval(input=feature, target=targets, n_support=parser.num_support_val)
            test_correct += correct
            test_total += parser.classes_per_it_val * parser.num_query_val
    test_acc = 100. * test_correct / test_total
    print('Testing Accuracy: {}'.format(test_acc))
示例#10
0
def train_model(train_dataset, train_num_each, val_dataset, val_num_each):
    """Train a ResNet+LSTM phase classifier over fixed-length frame sequences.

    Relies on module-level configuration globals: sequence_length, num_gpu,
    train_batch_size, val_batch_size, workers, use_gpu, epochs, learning_rate,
    momentum, dampening, weight_decay, use_nesterov, optimizer_choice,
    multi_optim, sgd_adjust_lr, sgd_gamma, crop_type, use_flip.

    :param train_dataset: dataset of (frame, label_1, label_2) training items
    :param train_num_each: per-video frame counts for the training videos
    :param val_dataset: dataset of (frame, label_1, label_2) validation items
    :param val_num_each: per-video frame counts for the validation videos
    :return: None (saves best weights to a .pth file and metric history to a .pkl)
    """
    num_train = len(train_dataset)
    num_val = len(val_dataset)

    # Valid sequence start indices (a sequence must not cross video boundaries)
    train_useful_start_idx = get_useful_start_idx(sequence_length,
                                                  train_num_each)
    val_useful_start_idx = get_useful_start_idx(sequence_length, val_num_each)

    # Truncate to a multiple of num_gpu so each GPU gets whole sequences
    num_train_we_use = len(train_useful_start_idx) // num_gpu * num_gpu
    num_val_we_use = len(val_useful_start_idx) // num_gpu * num_gpu
    # num_train_we_use = 8000
    # num_val_we_use = 800

    train_we_use_start_idx = train_useful_start_idx[0:num_train_we_use]
    val_we_use_start_idx = val_useful_start_idx[0:num_val_we_use]

    #    np.random.seed(0)
    # np.random.shuffle(train_we_use_start_idx)
    # Expand each start index into sequence_length consecutive frame indices
    train_idx = []
    for i in range(num_train_we_use):
        for j in range(sequence_length):
            train_idx.append(train_we_use_start_idx[i] + j)

    val_idx = []
    for i in range(num_val_we_use):
        for j in range(sequence_length):
            val_idx.append(val_we_use_start_idx[i] + j)

    num_train_all = len(train_idx)
    num_val_all = len(val_idx)
    print('num of train dataset: {:6d}'.format(num_train))
    print('num train start idx : {:6d}'.format(len(train_useful_start_idx)))
    print('last idx train start: {:6d}'.format(train_useful_start_idx[-1]))
    print('num of train we use : {:6d}'.format(num_train_we_use))
    print('num of all train use: {:6d}'.format(num_train_all))
    print('num of valid dataset: {:6d}'.format(num_val))
    print('num valid start idx : {:6d}'.format(len(val_useful_start_idx)))
    print('last idx valid start: {:6d}'.format(val_useful_start_idx[-1]))
    print('num of valid we use : {:6d}'.format(num_val_we_use))
    print('num of all valid use: {:6d}'.format(num_val_all))

    # The index lists are passed as `sampler`, fixing the exact frame order
    train_loader = DataLoader(train_dataset,
                              batch_size=train_batch_size,
                              sampler=train_idx,
                              num_workers=workers,
                              pin_memory=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=val_batch_size,
                            sampler=val_idx,
                            num_workers=workers,
                            pin_memory=False)
    model = resnet_lstm_dp()
    if use_gpu:
        model = model.cuda()

    model = DataParallel(model)
    # Summed (not averaged) loss; divided by sample counts below.
    # NOTE(review): `size_average` is the legacy (pre-0.4) PyTorch API.
    criterion = nn.CrossEntropyLoss(size_average=False)

    if multi_optim == 0:
        # Single learning rate for the whole network
        if optimizer_choice == 0:
            optimizer = optim.SGD(model.parameters(),
                                  lr=learning_rate,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                # NOTE(review): step_size=sgd_adjust_lr is 0 in this branch,
                # which is not a valid StepLR period -- a dedicated step-size
                # global was probably intended here; confirm.
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_adjust_lr,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    elif multi_optim == 1:
        # Per-module learning rates: CNN backbone at lr/10, LSTM and FC at lr
        if optimizer_choice == 0:
            optimizer = optim.SGD([
                {
                    'params': model.module.share.parameters()
                },
                {
                    'params': model.module.lstm.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc.parameters(),
                    'lr': learning_rate
                },
            ],
                                  lr=learning_rate / 10,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                # NOTE(review): same suspicious step_size=sgd_adjust_lr as above
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_adjust_lr,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam([
                {
                    'params': model.module.share.parameters()
                },
                {
                    'params': model.module.lstm.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc.parameters(),
                    'lr': learning_rate
                },
            ],
                                   lr=learning_rate / 10)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_accuracy = 0.0
    correspond_train_acc = 0.0

    # Per-epoch metric history, pickled at the end
    all_info = []
    all_train_accuracy = []
    all_train_loss = []
    all_val_accuracy = []
    all_val_loss = []

    for epoch in range(epochs):
        # np.random.seed(epoch)
        # Reshuffle sequence starts each epoch, then rebuild the frame-index
        # sampler and loader so batches change while sequences stay intact
        np.random.shuffle(train_we_use_start_idx)
        train_idx = []
        for i in range(num_train_we_use):
            for j in range(sequence_length):
                train_idx.append(train_we_use_start_idx[i] + j)

        train_loader = DataLoader(train_dataset,
                                  batch_size=train_batch_size,
                                  sampler=train_idx,
                                  num_workers=workers,
                                  pin_memory=False)

        model.train()
        train_loss = 0.0
        train_corrects = 0
        train_start_time = time.time()
        for data in train_loader:
            # labels_1 is unused here; labels_2 is the training target
            inputs, labels_1, labels_2 = data
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels = Variable(labels_2.cuda())
            else:
                inputs = Variable(inputs)
                labels = Variable(labels_2)
            optimizer.zero_grad()
            outputs = model.forward(inputs)
            _, preds = torch.max(outputs.data, 1)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # NOTE(review): loss.data[0] is the legacy (pre-0.4) scalar access
            train_loss += loss.data[0]
            train_corrects += torch.sum(preds == labels.data)
        train_elapsed_time = time.time() - train_start_time
        train_accuracy = train_corrects / num_train_all
        train_average_loss = train_loss / num_train_all

        # begin eval
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        val_start_time = time.time()
        for data in val_loader:
            inputs, labels_1, labels_2 = data
            # keep only the label of the last frame of each sequence
            labels_2 = labels_2[(sequence_length - 1)::sequence_length]
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels = Variable(labels_2.cuda())
            else:
                inputs = Variable(inputs)
                labels = Variable(labels_2)

            if crop_type == 0 or crop_type == 1:
                outputs = model.forward(inputs)
            elif crop_type == 5:
                # 5-crop TTA: fold crops into the batch dim, average the 5 logits
                inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
                inputs = inputs.view(-1, 3, 224, 224)
                outputs = model.forward(inputs)
                outputs = outputs.view(5, -1, 7)
                outputs = torch.mean(outputs, 0)
            elif crop_type == 10:
                # 10-crop TTA, same scheme
                inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
                inputs = inputs.view(-1, 3, 224, 224)
                outputs = model.forward(inputs)
                outputs = outputs.view(10, -1, 7)
                outputs = torch.mean(outputs, 0)

            # score only the last frame of each sequence (matches labels above)
            outputs = outputs[sequence_length - 1::sequence_length]

            _, preds = torch.max(outputs.data, 1)

            loss = criterion(outputs, labels)
            val_loss += loss.data[0]
            val_corrects += torch.sum(preds == labels.data)
        val_elapsed_time = time.time() - val_start_time
        val_accuracy = val_corrects / num_val_we_use
        val_average_loss = val_loss / num_val_we_use
        print('epoch: {:4d}'
              ' train in: {:2.0f}m{:2.0f}s'
              ' train loss: {:4.4f}'
              ' train accu: {:.4f}'
              ' valid in: {:2.0f}m{:2.0f}s'
              ' valid loss: {:4.4f}'
              ' valid accu: {:.4f}'.format(epoch, train_elapsed_time // 60,
                                           train_elapsed_time % 60,
                                           train_average_loss, train_accuracy,
                                           val_elapsed_time // 60,
                                           val_elapsed_time % 60,
                                           val_average_loss, val_accuracy))

        if optimizer_choice == 0:
            if sgd_adjust_lr == 0:
                exp_lr_scheduler.step()
            elif sgd_adjust_lr == 1:
                # plateau scheduler keys off the validation loss
                exp_lr_scheduler.step(val_average_loss)

        # Track best validation accuracy; on ties keep the weights with the
        # higher training accuracy
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            correspond_train_acc = train_accuracy
            best_model_wts = copy.deepcopy(model.state_dict())
        if val_accuracy == best_val_accuracy:
            if train_accuracy > correspond_train_acc:
                correspond_train_acc = train_accuracy
                best_model_wts = copy.deepcopy(model.state_dict())
        all_train_loss.append(train_average_loss)
        all_train_accuracy.append(train_accuracy)
        all_val_loss.append(val_average_loss)
        all_val_accuracy.append(val_accuracy)

    print('best accuracy: {:.4f} cor train accu: {:.4f}'.format(
        best_val_accuracy, correspond_train_acc))

    # Encode the best accuracies (x10000) into the artifact file names
    save_val = int("{:4.0f}".format(best_val_accuracy * 10000))
    save_train = int("{:4.0f}".format(correspond_train_acc * 10000))
    model_name = "lstm" \
                 + "_epoch_" + str(epochs) \
                 + "_length_" + str(sequence_length) \
                 + "_opt_" + str(optimizer_choice) \
                 + "_mulopt_" + str(multi_optim) \
                 + "_flip_" + str(use_flip) \
                 + "_crop_" + str(crop_type) \
                 + "_batch_" + str(train_batch_size) \
                 + "_train_" + str(save_train) \
                 + "_val_" + str(save_val) \
                 + ".pth"

    torch.save(best_model_wts, model_name)

    all_info.append(all_train_accuracy)
    all_info.append(all_train_loss)
    all_info.append(all_val_accuracy)
    all_info.append(all_val_loss)

    record_name = "lstm" \
                  + "_epoch_" + str(epochs) \
                  + "_length_" + str(sequence_length) \
                  + "_opt_" + str(optimizer_choice) \
                  + "_mulopt_" + str(multi_optim) \
                  + "_flip_" + str(use_flip) \
                  + "_crop_" + str(crop_type) \
                  + "_batch_" + str(train_batch_size) \
                  + "_train_" + str(save_train) \
                  + "_val_" + str(save_val) \
                  + ".pkl"

    with open(record_name, 'wb') as f:
        pickle.dump(all_info, f)
    print()
示例#11
0
File: train2.py  Project: urafi/reid
def main(args):
    """Train a ResNet-50 person re-ID model on Market-1501.

    Builds train/query/gallery loaders, trains with SGD (pretrained base
    layers at 0.1x the base learning rate, new layers at 1x), evaluates
    rank-1 accuracy after every epoch, appends it to ``losses/rank1.txt``
    and checkpoints model + optimizer state per epoch.
    """
    # Seed every RNG source; the seed itself is random but printed so a
    # given run can be reproduced later.
    manualSeed = random.randint(1, 100000)
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    cudnn.benchmark = True
    #cudnn.deterministic = False
    cudnn.enabled = True

    root = '//media/ekodirov/1002d198-cc12-4f27-aae2-fdef0f8cea56/Anvpersons/ReID_datasets/deep-person-reid-datasets/hhl/data/'

    # Training identities are relabelled to a contiguous 0..C-1 range so
    # they can feed the classification head; eval splits keep raw labels.
    train_source, num_classes = preprocess(root + 'market/bounding_box_train',
                                           relabel=True)
    gallery, _ = preprocess(root + 'market/bounding_box_test', relabel=False)
    query, _ = preprocess(root + 'market/query', relabel=False)

    marketTrain = Market('train', train_source,
                         root + 'market/bounding_box_train/', 'train',
                         args.height, args.width)
    galleryds = Market('val', gallery, root + 'market/bounding_box_test/',
                       'gallery', args.height, args.width)
    querds = Market('val', query, root + 'market/query/', 'query', args.height,
                    args.width)

    num_epochs = args.epochs
    train_batch_size = args.batch_size
    # NOTE(review): test_batch_size is never used below — the query and
    # gallery loaders reuse train_batch_size. Confirm whether the eval
    # loaders were meant to use batch_size=64 instead.
    test_batch_size = 64
    train_loader = DataLoader(marketTrain,
                              batch_size=train_batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=False)
    query_loader = DataLoader(querds,
                              batch_size=train_batch_size,
                              shuffle=False,
                              num_workers=8,
                              pin_memory=False)
    gallery_loader = DataLoader(galleryds,
                                batch_size=train_batch_size,
                                shuffle=False,
                                num_workers=8,
                                pin_memory=False)

    reidNet = resnet50(pretrained=True, num_classes=num_classes)
    model = DataParallel(reidNet).cuda()

    # Optimizer: pretrained backbone ('base') gets a 0.1x lr multiplier,
    # freshly initialized layers train at the full rate.
    if hasattr(model.module, 'base'):
        base_param_ids = set(map(id, model.module.base.parameters()))
        new_params = [
            p for p in model.parameters() if id(p) not in base_param_ids
        ]
        param_groups = [{
            'params': model.module.base.parameters(),
            'lr_mult': 0.1
        }, {
            'params': new_params,
            'lr_mult': 1.0
        }]
        print('Learning rate is set.')
    else:
        param_groups = model.parameters()
    optimiser = torch.optim.SGD(param_groups,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    #optimiser = torch.optim.Adam(model.parameters(), lr=lr, eps=.1)
    # optimiser = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

    # Schedule learning rate: step decay by lr_factor every step_size
    # epochs, scaled per group by its lr_mult.
    step_size = args.step_size

    def adjust_lr(epoch):
        _lr = args.lr * (args.lr_factor**(epoch // step_size))
        print(_lr)
        for g in optimiser.param_groups:
            g['lr'] = _lr * g.get('lr_mult', 1)

    #checkpoint = torch.load('models_epoch/reidNet_10.pth')
    #model.load_state_dict(checkpoint['state_dict'])
    #optimiser.load_state_dict(checkpoint['optimizer'])

    # NOTE(review): 'elementwise_mean' is the legacy name of the 'mean'
    # reduction; newer torch releases reject it — confirm torch version.
    criterion = torch.nn.CrossEntropyLoss(reduction='elementwise_mean').cuda()

    start_epoch = 0  #checkpoint['epoch'] + 1

    for epoch in range(start_epoch, num_epochs):
        adjust_lr(epoch)

        print("Starting Epoch [%d]" % (epoch))

        # Return value (epoch loss) is not used further; train() runs
        # one full pass over train_loader.
        train(train_loader, model, optimiser, criterion)

        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimiser.state_dict(),
        }

        evaluator = Evaluator(model)
        # metrics[0] is the rank-1 score in [0, 1] (renamed from `all`,
        # which shadowed the builtin).
        metrics = evaluator.evaluate(query_loader, gallery_loader, query,
                                     gallery, args.output_feature,
                                     args.rerank)

        # The with-block closes the file on exit; no explicit close needed.
        with open('losses/rank1.txt', 'a') as the_file:
            the_file.write(str(metrics[0] * 100) + '\n')

        model_name = '/media/ekodirov/1002d198-cc12-4f27-aae2-fdef0f8cea56/reid/models_epoch/reidNet_' \
                     + str(epoch) + '_' + str(metrics[0] * 100)[:5] +'.pth'
        torch.save(state, model_name)
def main(args):
    """Train a stacked-hourglass pose model on MPII and checkpoint it.

    Reads from ``args``: arch ('hg1'|'hg2'|'hg8'), checkpoint (output
    dir), image_path, lr / momentum / weight_decay, train_batch /
    test_batch, workers, start_epoch / epochs, schedule / gamma,
    snapshot, and resume (optional checkpoint path).

    Raises:
        Exception: if ``args.arch`` is not one of the known models.
    """
    # Select the hardware device to use for inference.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    # (do_training_epoch is expected to re-enable them locally.)
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    # Instantiate the requested hourglass depth (1, 2 or 8 stacks).
    if args.arch == 'hg1':
        model = hg1(pretrained=False)
    elif args.arch == 'hg2':
        model = hg2(pretrained=False)
    elif args.arch == 'hg8':
        model = hg8(pretrained=False)
    else:
        raise Exception('unrecognised model architecture: ' + args.arch)

    model = DataParallel(model).to(device)

    optimizer = RMSprop(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

    best_acc = 0

    # optionally resume from a checkpoint
    # (restores epoch counter, best accuracy, model and optimizer state,
    #  and appends to the existing log instead of starting a new one).
    if args.resume:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # create data loader
    train_dataset = Mpii(args.image_path, is_train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    val_dataset = Mpii(args.image_path, is_train=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    # train and eval
    lr = args.lr
    for epoch in trange(args.start_epoch,
                        args.epochs,
                        desc='Overall',
                        ascii=True):
        # Step-decay the learning rate at the milestones in args.schedule.
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)

        # train for one epoch
        train_loss, train_acc = do_training_epoch(train_loader,
                                                  model,
                                                  device,
                                                  Mpii.DATA_INFO,
                                                  optimizer,
                                                  acc_joints=Mpii.ACC_JOINTS)

        # evaluate on validation set
        valid_loss, valid_acc, predictions = do_validation_epoch(
            val_loader,
            model,
            device,
            Mpii.DATA_INFO,
            False,
            acc_joints=Mpii.ACC_JOINTS)

        # print metrics
        tqdm.write(
            f'[{epoch + 1:3d}/{args.epochs:3d}] lr={lr:0.2e} '
            f'train_loss={train_loss:0.4f} train_acc={100 * train_acc:0.2f} '
            f'valid_loss={valid_loss:0.4f} valid_acc={100 * valid_acc:0.2f}')

        # append logger file
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
        logger.plot_to_file(os.path.join(args.checkpoint, 'log.svg'),
                            ['Train Acc', 'Val Acc'])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
示例#13
0
class Solver():
    """StyleGAN2-style GAN trainer.

    Trains a mapping network + synthesis generator against a
    discriminator with the non-saturating logistic loss, lazy R1
    regularization on the discriminator, lazy path-length
    regularization on the generator, and an EMA copy of the generator
    used for sampling. Logging goes to TensorBoard.
    """

    def __init__(self, config, channel_list):
        # Config - Model
        self.z_dim = config.z_dim
        self.channel_list = channel_list

        # Config - Training
        self.batch_size = config.batch_size
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.decay_ratio = config.decay_ratio
        self.decay_iter = config.decay_iter
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.n_critic = config.n_critic
        self.lambda_gp = config.lambda_gp
        self.max_iter = config.max_iter

        # Lazy regularization intervals and strengths.
        self.r1_iter = config.r1_iter
        self.r1_lambda = config.r1_lambda
        self.ppl_iter = config.ppl_iter
        self.ppl_lambda = config.ppl_lambda

        # Config - Test: fixed latents so sample grids are comparable
        # across iterations.
        self.fixed_z = torch.randn(512, config.z_dim).to(dev)

        # Config - Path
        self.data_root = config.data_root
        self.log_root = config.log_root
        self.model_root = config.model_root
        self.sample_root = config.sample_root

        # Config - Miscellanceous
        self.print_loss_iter = config.print_loss_iter
        self.save_image_iter = config.save_image_iter
        self.save_parameter_iter = config.save_parameter_iter
        self.save_log_iter = config.save_log_iter

        self.writer = SummaryWriter(self.log_root)

    def build_model(self):
        """Instantiate G, EMA-G, D and the mapping network with their
        optimizers and step-decay schedulers."""
        self.G = Generator(channel_list=self.channel_list)
        self.G_ema = Generator(channel_list=self.channel_list)
        self.D = Discriminator(channel_list=self.channel_list)
        self.M = MappingNetwork(z_dim=self.z_dim)

        self.G = DataParallel(self.G).to(dev)
        self.G_ema = DataParallel(self.G_ema).to(dev)
        self.D = DataParallel(self.D).to(dev)
        self.M = DataParallel(self.M).to(dev)

        # Generator and mapping network are optimized jointly.
        G_M_params = list(self.G.parameters()) + list(self.M.parameters())

        self.g_optimizer = torch.optim.Adam(params=G_M_params,
                                            lr=self.g_lr,
                                            betas=[self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(params=self.D.parameters(),
                                            lr=self.d_lr,
                                            betas=[self.beta1, self.beta2])

        self.g_scheduler = lr_scheduler.StepLR(self.g_optimizer,
                                               step_size=self.decay_iter,
                                               gamma=self.decay_ratio)
        self.d_scheduler = lr_scheduler.StepLR(self.d_optimizer,
                                               step_size=self.decay_iter,
                                               gamma=self.decay_ratio)

        print("Print model G, D")
        print(self.G)
        print(self.D)

    def load_model(self, pkl_path, channel_list):
        """Rebuild the networks and restore their weights from a
        checkpoint produced by save_model()."""
        ckpt = torch.load(pkl_path)

        self.G = Generator(channel_list=channel_list)
        self.G_ema = Generator(channel_list=channel_list)
        self.D = Discriminator(channel_list=channel_list)
        self.M = MappingNetwork(z_dim=self.z_dim)

        self.G = DataParallel(self.G).to(dev)
        self.G_ema = DataParallel(self.G_ema).to(dev)
        self.D = DataParallel(self.D).to(dev)
        self.M = DataParallel(self.M).to(dev)

        self.G.load_state_dict(ckpt["G"])
        self.G_ema.load_state_dict(ckpt["G_ema"])
        self.D.load_state_dict(ckpt["D"])
        self.M.load_state_dict(ckpt["M"])

    def save_model(self, iters):
        """Write all four networks' state dicts to ckpt_<iters>.pkl."""
        file_name = 'ckpt_%d.pkl' % iters
        ckpt_path = os.path.join(self.model_root, file_name)
        ckpt = {
            'M': self.M.state_dict(),
            'G': self.G.state_dict(),
            'G_ema': self.G_ema.state_dict(),
            'D': self.D.state_dict()
        }
        torch.save(ckpt, ckpt_path)

    def save_img(self, iters, fixed_w):
        """Sample a grid of images from the EMA generator and save it."""
        img_path = os.path.join(self.sample_root, "%d.png" % iters)
        with torch.no_grad():
            fixed_w = fixed_w[:self.batch_size * 2]
            dlatents_in = make_latents(fixed_w, self.batch_size,
                                       len(self.channel_list))
            generated_imgs, _ = self.G_ema(dlatents_in)
            # Map outputs from [-1, 1] to [0, 1] before saving.
            save_image(
                make_grid(generated_imgs.cpu() / 2 + 1 / 2, nrow=4, padding=2),
                img_path)

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def lr_update(self):
        """Advance both LR schedulers one step."""
        self.g_scheduler.step()
        self.d_scheduler.step()

    def set_phase(self, mode="train"):
        """Switch every network to train() or eval() mode."""
        if mode == "train":
            self.G.train()
            self.G_ema.train()
            self.D.train()
            self.M.train()

        elif mode == "test":
            self.G.eval()
            self.G_ema.eval()
            self.D.eval()
            self.M.eval()

    def exponential_moving_average(self, beta=0.999):
        """Update G_ema in place: ema = beta * ema + (1 - beta) * G."""
        with torch.no_grad():
            G_param_dict = dict(self.G.named_parameters())
            for name, g_ema_param in self.G_ema.named_parameters():
                g_param = G_param_dict[name]
                g_ema_param.copy_(beta * g_ema_param + (1. - beta) * g_param)

    def r1_regularization(self, real_pred, real_img):
        """R1 penalty: squared gradient norm of D's output w.r.t. real
        images, averaged over the batch."""
        grad_real = torch.autograd.grad(outputs=real_pred.sum(),
                                        inputs=real_img,
                                        create_graph=True)[0]
        grad_penalty = grad_real.pow(2).view(grad_real.size(0),
                                             -1).sum(1).mean()
        return grad_penalty

    def path_length_regularization(self,
                                   fake_img,
                                   latents,
                                   mean_path_length,
                                   decay=0.01):
        """Path-length penalty plus the EMA-updated mean path length."""
        noise = torch.randn_like(fake_img) / math.sqrt(
            fake_img.shape[2] * fake_img.shape[3])
        grad = torch.autograd.grad(outputs=(fake_img * noise).sum(),
                                   inputs=latents,
                                   create_graph=True)[0]
        path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
        path_mean = mean_path_length + decay * (path_lengths.mean() -
                                                mean_path_length)
        path_penalty = (path_lengths - path_mean).pow(2).mean()
        return path_penalty, path_mean.detach(), path_lengths

    def train(self):
        """Run the full adversarial training loop for max_iter steps."""
        # build model
        self.build_model()
        loader = data_loader(self.data_root, self.batch_size, img_size=512)
        loader = iter(cycle(loader))
        mean_path_length = torch.tensor(0.0).to(dev)
        average_path_length = torch.tensor(0.0).to(dev)

        # Bug fix: pre-initialize the losses that are only assigned
        # inside conditional branches, so the logging code below cannot
        # hit a NameError before those branches have run (e.g. g_loss /
        # path_loss when n_critic > 1 at iters == 0).
        r1_loss = torch.tensor(0.0)
        g_loss = torch.tensor(0.0)
        path_loss = torch.tensor(0.0)

        for iters in tqdm(range(self.max_iter + 1)):
            real_img = next(loader)
            real_img = real_img.to(dev)
            # ===============================================================#
            #                    1. Train the discriminator                  #
            # ===============================================================#
            self.set_phase(mode="train")
            self.reset_grad()

            # Compute loss with real images.
            d_real_out = self.D(real_img)
            d_loss_real = F.softplus(-d_real_out).mean()

            # Compute loss with fake images (detached: D step only).
            z = torch.randn(2 * self.batch_size, self.z_dim).to(dev)
            w = self.M(z)
            dlatents_in = make_latents(w, self.batch_size,
                                       len(self.channel_list))
            fake_img, _ = self.G(dlatents_in)
            d_fake_out = self.D(fake_img.detach())
            d_loss_fake = F.softplus(d_fake_out).mean()

            d_loss = d_loss_real + d_loss_fake

            # Lazy R1 regularization every r1_iter steps, scaled by the
            # interval to keep the effective strength constant.
            if iters % self.r1_iter == 0:
                real_img.requires_grad = True
                d_real_out = self.D(real_img)
                r1_loss = self.r1_regularization(d_real_out, real_img)
                r1_loss = self.r1_lambda / 2 * r1_loss * self.r1_iter
                d_loss = d_loss + r1_loss

            d_loss.backward()
            self.d_optimizer.step()
            # ===============================================================#
            #                      2. Train the Generator                    #
            # ===============================================================#

            if (iters + 1) % self.n_critic == 0:
                self.reset_grad()

                # Compute loss with fake images.
                z = torch.randn(2 * self.batch_size, self.z_dim).to(dev)
                w = self.M(z)
                dlatents_in = make_latents(w, self.batch_size,
                                           len(self.channel_list))
                fake_img, _ = self.G(dlatents_in)
                d_fake_out = self.D(fake_img)
                g_loss = F.softplus(-d_fake_out).mean()

                # Lazy path-length regularization every ppl_iter steps.
                if iters % self.ppl_iter == 0:
                    path_loss, mean_path_length, path_length = self.path_length_regularization(
                        fake_img, dlatents_in, mean_path_length)
                    path_loss = path_loss * self.ppl_iter * self.ppl_lambda
                    g_loss = g_loss + path_loss
                    mean_path_length = mean_path_length.mean()
                    average_path_length += mean_path_length.mean()

                # Backward and optimize.
                g_loss.backward()
                self.g_optimizer.step()

            # ===============================================================#
            #                   3. Save parameters and images                #
            # ===============================================================#
            # self.lr_update()
            torch.cuda.synchronize()
            self.set_phase(mode="test")
            self.exponential_moving_average()

            # Print total loss
            if iters % self.print_loss_iter == 0:
                print(
                    "Iter : [%d/%d], D_loss : [%.3f, %.3f, %.3f.], G_loss : %.3f, R1_reg : %.3f, "
                    "PPL_reg : %.3f, Path_length : %.3f" %
                    (iters, self.max_iter, d_loss.item(), d_loss_real.item(),
                     d_loss_fake.item(), g_loss.item(), r1_loss.item(),
                     path_loss.item(), mean_path_length.item()))

            # Save generated images.
            if iters % self.save_image_iter == 0:
                fixed_w = self.M(self.fixed_z)
                self.save_img(iters, fixed_w)

            # Save the G and D parameters.
            if iters % self.save_parameter_iter == 0:
                self.save_model(iters)

            # Save the logs on the tensorboard.
            if iters % self.save_log_iter == 0:
                self.writer.add_scalar('g_loss/g_loss', g_loss.item(), iters)
                self.writer.add_scalar('d_loss/d_loss_total', d_loss.item(),
                                       iters)
                self.writer.add_scalar('d_loss/d_loss_real',
                                       d_loss_real.item(), iters)
                self.writer.add_scalar('d_loss/d_loss_fake',
                                       d_loss_fake.item(), iters)
                self.writer.add_scalar('reg/r1_regularization', r1_loss.item(),
                                       iters)
                self.writer.add_scalar('reg/ppl_regularization',
                                       path_loss.item(), iters)

                self.writer.add_scalar('length/path_length',
                                       mean_path_length.item(), iters)
                self.writer.add_scalar(
                    'length/avg_path_length',
                    average_path_length.item() / (iters // self.ppl_iter + 1),
                    iters)
示例#14
0
def trainNucleusModel(
    model,
    checkpoint_path: str,
    data_loader: DataLoader,
    data_loader_test: DataLoader = None,
    n_gradient_updates=100000,
    effective_batch_size=2,
    print_freq=1,
    window_size=None,
    smoothing_window=10,
    test_evaluate_freq=5,
    test_maxDets=None,
    crop_inference_to_fov=False,
    optimizer_type='SGD',
    optimizer_params=None,
    lr_scheduler_type=None,
    lr_scheduler_params=None,
    loss_weights=None,
    n_testtime_augmentations=None,
    freeze_det_after=None,
):
    """Train a nucleus detection/classification model with checkpointing.

    Runs epochs until roughly ``n_gradient_updates`` optimizer steps have
    been taken, resuming from ``checkpoint_path`` if it exists. After each
    epoch: optionally evaluates on ``data_loader_test`` (every
    ``test_evaluate_freq`` epochs, once per requested test-time
    augmentation level), saves smoothed training losses, testing metrics
    and accuracy plots, and writes model/optimizer/epoch state next to
    ``checkpoint_path`` (.ckpt / .optim / .meta).

    Parameters of note:
        effective_batch_size: gradient-accumulation batch size used to
            convert epochs <-> gradient updates.
        test_maxDets: three COCO-style maxDets thresholds
            (default [1, 100, 300]).
        loss_weights: optional dict of per-loss weights; rescaled so the
            largest weight is 1.
        n_testtime_augmentations: list of augmentation counts to evaluate
            at (default [0]).
        freeze_det_after: number of gradient updates after which the
            detection branch is frozen (classification keeps training);
            None disables freezing.
    """
    # some basic checks
    # NOTE: assert is stripped under `python -O`; these are sanity checks,
    # not input validation for untrusted callers.
    assert smoothing_window <= len(data_loader) // effective_batch_size
    test_maxDets = test_maxDets or [1, 100, 300]
    assert len(test_maxDets) == 3
    n_testtime_augmentations = n_testtime_augmentations or [0]
    freeze_det_after = freeze_det_after or np.inf

    # make sure max loss weight is 1
    if loss_weights is not None:
        mxl = max(loss_weights.values())
        loss_weights = {k: v / mxl for k, v in loss_weights.items()}

    # train on the GPU or on the CPU, if a GPU is not available
    device = torch.device('cuda') if torch.cuda.is_available() else \
        torch.device('cpu')

    # NOTE:
    #  The torch data parallelism loads all images into ONE GPU
    #  (the "main" one), copies the model on all GPUs, distributes the data to
    #  them, passes forward mode (which INCLUDES the loss in faster/maskrcnn),
    #  collects the output from all GPUs into the "main" GPU, then
    #  backprops the loss in parallel for all GPUs. See the diagram here:
    #  https://blog.paperspace.com/pytorch-memory-multi-gpu-debugging/
    #  What this means is that:
    #  1- All data needs to fit in one GPU EVEN THOUGH everything is parallel
    #  2- There's an imbalanced load on the GPUs, with the "main" GPU
    #  (gpu0, usually) handling more work & utilizing more memory than others.
    #  See this discussion:
    #  https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551
    #  Using DistributedDataParallel may help a bit in speed, but it does not
    #  help with the distibuted load thing .. everything STILL needs to fit to
    #  one GPU. I'll just train one fold per GPU.

    # GPU parallelize if gpus are available. See:
    #   https://pytorch.org/tutorials/beginner/blitz/ ...
    #   data_parallel_tutorial.html#create-model-and-dataparallel
    if NGPUS > 1:

        print(f"Let's use {NGPUS} GPUs!")
        model = DataParallel(model)

        # # DistributedDataParallel is next to impossible to get to work!
        # # I tried this, among MANY others:
        # # https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/3
        # # I spent two full days trying to make it work, with all sorts of
        # # configs, with no luck. I give up!! It's not worth the time
        # # or effort; I'll just launch each fold independently
        # os.environ['MASTER_ADDR'] = 'localhost'
        # os.environ['MASTER_PORT'] = '29500'
        # torch.distributed.init_process_group(
        #     backend='nccl', init_method="env://", rank=0,
        #     world_size=torch.cuda.device_count(),  # use all visible gpus
        # )
        # model = DistributedDataParallel(model)

    # move model to the right device
    model.to(device)

    # # Mohamed: use half floats
    # if ISCUDA:
    #     model = model.to(torch.float16)

    # construct an optimizer
    optimizer = _get_optimizer(model=model,
                               optimizer_type=optimizer_type,
                               optimizer_params=optimizer_params)

    # load weights and optimizer state (resume one epoch past the saved one)
    if os.path.exists(checkpoint_path):
        ckpt = load_ckp(checkpoint_path=checkpoint_path,
                        model=model,
                        optimizer=optimizer)
        model = ckpt['model']
        optimizer = ckpt['optimizer']
        start_epoch = ckpt['epoch']
        start_epoch += 1
    else:
        start_epoch = 1

    # learning rate scheduler
    if lr_scheduler_type is None:
        lr_scheduler = None
    elif lr_scheduler_type == 'step':
        lr_scheduler_params = lr_scheduler_params or {
            'step_size': 50,
            'gamma': 0.1,
        }
        # When resuming, tell StepLR how many scheduler steps have already
        # happened so the decayed LR is restored; -1 means a fresh start.
        grup = -1 if start_epoch == 1 else \
            (start_epoch - 1) * data_loader.batch_size
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                       last_epoch=grup,
                                                       **lr_scheduler_params)
    else:
        raise NotImplementedError(f'Unknown lr_scheduler: {lr_scheduler_type}')

    # keep track of all batch losses if window size is not given
    # otherwise, only keep track of the last window_size batches
    window_size = window_size or len(data_loader)

    # let's train it for n_epochs

    grups_per_epoch = int(len(data_loader.dataset) / effective_batch_size)
    n_epochs = int(n_gradient_updates // grups_per_epoch)

    frozen_det = False

    for epoch in range(start_epoch, n_epochs + 1):

        # Maybe freeze detection, but keep training classification
        # (one-shot: _freeze_detection rebuilds the optimizer over the
        # remaining trainable parameters).
        if (not frozen_det) and (
            (epoch - 1) * grups_per_epoch > freeze_det_after):
            model, optimizer = _freeze_detection(
                model=model,
                optimizer_type=optimizer_type,
                optimizer_params=optimizer_params,
            )
            frozen_det = True

        # train for one epoch
        trl = train_one_epoch(
            model=model,
            device=device,
            epoch=epoch,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            data_loader=data_loader,
            effective_batch_size=effective_batch_size,
            loss_weights=loss_weights,
            print_freq=print_freq,
            window_size=window_size,
        )
        # Keep only the loss meters from the returned metric logger.
        trl = {
            ltype: v
            for ltype, v in trl.meters.items() if ltype.startswith('loss')
        }

        # evaluate on the test dataset (every test_evaluate_freq epochs,
        # and always on the final epoch)
        tsls = []
        if (data_loader_test is not None) and ((epoch == n_epochs) or (
            (epoch - 1) % test_evaluate_freq == 0)):

            # get performance at all requested testtime augmentation levels
            for ntta in n_testtime_augmentations:
                model.n_testtime_augmentations = ntta
                tsls.append(
                    _evaluate_on_testing_set(
                        model=model,
                        data_loader_test=data_loader_test,
                        device=device,
                        test_maxDets=test_maxDets,
                        crop_inference_to_fov=crop_inference_to_fov))

        # save training loss
        _save_training_losses(trl=trl,
                              epoch=epoch,
                              smoothing_window=smoothing_window,
                              grups_per_epoch=grups_per_epoch,
                              checkpoint_path=checkpoint_path)

        # save testing metrics (one file per augmentation level)
        for i, tsl in enumerate(tsls):
            _save_testing_metrics(
                tsl=tsl,
                epoch=epoch,
                checkpoint_path=checkpoint_path,
                grup=epoch * grups_per_epoch,
                postfix=f'_{n_testtime_augmentations[i]}_TestTimeAugs',
            )

        # plot training and testing
        for i, ntta in enumerate(n_testtime_augmentations):
            pu.plot_accuracy_progress(
                checkpoint_path=checkpoint_path,
                postfix=f'_{ntta}_TestTimeAugs',
            )

        # create checkpoint and save
        print("*--- SAVING CHECKPOINT!! ---*")
        torch.save(model.state_dict(), f=checkpoint_path)
        torch.save(optimizer.state_dict(),
                   f=checkpoint_path.replace('.ckpt', '.optim'))
        meta = {'epoch': epoch}
        with open(checkpoint_path.replace('.ckpt', '.meta'), 'wb') as f:
            pickle.dump(meta, f)
示例#15
0
            if featureLs is None:
                featureLs = featureL
            else:
                featureLs = np.concatenate((featureLs, featureL), 0)
            if featureRs is None:
                featureRs = featureR
            else:
                featureRs = np.concatenate((featureRs, featureR), 0)

        result = {'fl': featureLs, 'fr': featureRs, 'fold': folds, 'flag': flags}
        # save tmp_result
        scipy.io.savemat('./result/tmp_result.mat', result)
        accs = evaluation_10_fold('./result/tmp_result.mat')
        _print('    ave: {:.4f}'.format(np.mean(accs) * 100))

    # save model
    if epoch % SAVE_FREQ == 0:
        msg = 'Saving checkpoint: {}'.format(epoch)
        _print(msg)
        if multi_gpus:
            net_state_dict = net.module.state_dict()
        else:
            net_state_dict = net.state_dict()
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        torch.save({
            'epoch': epoch,
            'net_state_dict': net_state_dict},
            os.path.join(save_dir, '%03d.ckpt' % epoch))
print('finishing training')
示例#16
0
文件: GAN.py 项目: kristofe/BMSG-GAN
class MSG_GAN:
    """ Unconditional TeacherGAN

        args:
            depth: depth of the GAN (will be used for each generator and discriminator)
            latent_size: latent size of the manifold used by the GAN
            use_eql: whether to use the equalized learning rate
            use_ema: whether to use exponential moving averages.
            ema_decay: value of ema decay. Used only if use_ema is True
            device: device to run the GAN on (GPU / CPU)
    """

    def __init__(self, depth=7, latent_size=512,
                 use_eql=True, use_ema=True, ema_decay=0.999,
                 device=th.device("cpu")):
        """ constructor for the class """
        from torch.nn import DataParallel

        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)

        # Parallelize them if required:
        if device == th.device("cuda"):
            self.gen = DataParallel(self.gen)
            self.dis = Discriminator(depth, latent_size,
                                     use_eql=use_eql, gpu_parallelize=True).to(device)
        else:
            # BUGFIX: the CPU branch previously hard-coded use_eql=True,
            # silently ignoring the constructor argument. It now honours
            # `use_eql` on every device, matching the CUDA branch above.
            self.dis = Discriminator(depth, latent_size, use_eql=use_eql).to(device)

        # state of the object
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.use_eql = use_eql
        self.latent_size = latent_size
        self.depth = depth
        self.device = device

        if self.use_ema:
            from MSG_GAN.CustomLayers import update_average

            # create a shadow copy of the generator
            self.gen_shadow = copy.deepcopy(self.gen)

            # updater function:
            self.ema_updater = update_average

            # initialize the gen_shadow weights equal to the
            # weights of gen (beta=0 means a full copy, no averaging yet)
            self.ema_updater(self.gen_shadow, self.gen, beta=0)

        # by default the generator and discriminator are in eval mode
        self.gen.eval()
        self.dis.eval()
        if self.use_ema:
            self.gen_shadow.eval()

    def generate_samples(self, num_samples):
        """
        generate samples using this gan
        :param num_samples: number of samples to be generated
        :return: generated samples tensor: list[ Tensor(B x H x W x C)]
        """
        noise = th.randn(num_samples, self.latent_size).to(self.device)
        generated_images = self.gen(noise)

        # reshape the generated images from NCHW to NHWC and map the
        # tanh output range [-1, 1] into [0, 1]
        generated_images = list(map(lambda x: (x.detach().permute(0, 2, 3, 1) / 2) + 0.5,
                                    generated_images))

        return generated_images

    def optimize_discriminator(self, dis_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on discriminator using the batch of data
        :param dis_optim: discriminator optimizer
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
                           should contain a list of tensors at different scales
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples; detach so no gradients flow
        # back into the generator during the discriminator step
        fake_samples = self.gen(noise)
        fake_samples = list(map(lambda x: x.detach(), fake_samples))

        loss = loss_fn.dis_loss(real_batch, fake_samples)

        # optimize discriminator
        dis_optim.zero_grad()
        loss.backward()
        dis_optim.step()

        return loss.item()

    def optimize_generator(self, gen_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on generator using the batch of data
        :param gen_optim: generator optimizer
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
                           should contain a list of tensors at different scales
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples
        fake_samples = self.gen(noise)

        loss = loss_fn.gen_loss(real_batch, fake_samples)

        # optimize generator
        gen_optim.zero_grad()
        loss.backward()
        gen_optim.step()

        # if self.use_ema is true, apply the moving average here:
        if self.use_ema:
            self.ema_updater(self.gen_shadow, self.gen, self.ema_decay)

        return loss.item()

    def create_grid(self, samples, img_files):
        """
        utility function to create a grid of GAN samples
        :param samples: generated samples for storing list[Tensors]
        :param img_files: list of names of files to write
        :return: None (saves multiple files)
        """
        from torchvision.utils import save_image
        from torch.nn.functional import interpolate
        from numpy import sqrt, power

        # dynamically adjust the colour of the images
        samples = [Generator.adjust_dynamic_range(sample) for sample in samples]

        # resize the samples so every scale is saved at the same resolution:
        for i in range(len(samples)):
            samples[i] = interpolate(samples[i],
                                     scale_factor=power(2,
                                                        self.depth - 1 - i))
        # save the images:
        for sample, img_file in zip(samples, img_files):
            save_image(sample, img_file, nrow=int(sqrt(sample.shape[0])),
                       normalize=True, scale_each=True, padding=0)

    def train(self, data, gen_optim, dis_optim, loss_fn, normalize_latents=True,
              start=1, num_epochs=12, feedback_factor=10, checkpoint_factor=1,
              data_percentage=100, num_samples=36,
              log_dir=None, sample_dir="./samples",
              save_dir="./models"):
        """
        Method for training the network
        :param data: pytorch dataloader which iterates over images
        :param gen_optim: Optimizer for generator.
                          please wrap this inside a Scheduler if you want to
        :param dis_optim: Optimizer for discriminator.
                          please wrap this inside a Scheduler if you want to
        :param loss_fn: Object of GANLoss
        :param normalize_latents: whether to normalize the latent vectors during training
        :param start: starting epoch number
        :param num_epochs: total number of epochs to run for (ending epoch number)
                           note this is absolute and not relative to start
        :param feedback_factor: number of logs generated and samples generated
                                during training per epoch
        :param checkpoint_factor: save model after these many epochs
        :param data_percentage: amount of data to be used
        :param num_samples: number of samples to be drawn for feedback grid
        :param log_dir: path to directory for saving the loss.log file
        :param sample_dir: path to directory for saving generated samples' grids
        :param save_dir: path to directory for saving the trained models
        :return: None (writes multiple files to disk)
        """

        from torch.nn.functional import avg_pool2d

        # turn the generator and discriminator into train mode
        self.gen.train()
        self.dis.train()

        assert isinstance(gen_optim, th.optim.Optimizer), \
            "gen_optim is not an Optimizer"
        assert isinstance(dis_optim, th.optim.Optimizer), \
            "dis_optim is not an Optimizer"

        print("Starting the training process ... ")

        # create fixed_input for debugging
        fixed_input = th.randn(num_samples, self.latent_size).to(self.device)
        if normalize_latents:
            fixed_input = (fixed_input
                           / fixed_input.norm(dim=-1, keepdim=True)
                           * (self.latent_size ** 0.5))

        # create a global time counter
        global_time = time.time()
        global_step = 0

        for epoch in range(start, num_epochs + 1):
            start_time = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            # a DataLoader already knows its length; no need for iter()
            total_batches = len(data)

            limit = int((data_percentage / 100) * total_batches)

            for (i, batch) in enumerate(data, 1):

                # extract current batch of data for training
                images = batch.to(self.device)
                extracted_batch_size = images.shape[0]

                # create a list of downsampled images from the real images
                # (use `p` so the comprehension doesn't shadow the batch index `i`):
                images = [images] + [avg_pool2d(images, int(np.power(2, p)))
                                     for p in range(1, self.depth)]
                images = list(reversed(images))

                # sample some random latent points
                gan_input = th.randn(
                    extracted_batch_size, self.latent_size).to(self.device)

                # normalize them if asked
                if normalize_latents:
                    gan_input = (gan_input
                                 / gan_input.norm(dim=-1, keepdim=True)
                                 * (self.latent_size ** 0.5))

                # optimize the discriminator:
                dis_loss = self.optimize_discriminator(dis_optim, gan_input,
                                                       images, loss_fn)

                # optimize the generator:
                gen_loss = self.optimize_generator(gen_optim, gan_input,
                                                   images, loss_fn)

                # provide a loss feedback
                if i % (int(limit / feedback_factor) + 1) == 0 or i == 1:     # Avoid div by 0 error on small training sets
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print("Elapsed [%s] batch: %d  d_loss: %f  g_loss: %f"
                          % (elapsed, i, dis_loss, gen_loss))

                    # also write the losses to the log file:
                    if log_dir is not None:
                        log_file = os.path.join(log_dir, "loss.log")
                        os.makedirs(os.path.dirname(log_file), exist_ok=True)
                        with open(log_file, "a") as log:
                            log.write(str(global_step) + "\t" + str(dis_loss) +
                                      "\t" + str(gen_loss) + "\n")

                    # create a grid of samples and save it
                    reses = [str(int(np.power(2, dep))) + "_x_"
                             + str(int(np.power(2, dep)))
                             for dep in range(2, self.depth + 2)]
                    gen_img_files = [os.path.join(sample_dir, res, "gen_" +
                                                  str(epoch) + "_" +
                                                  str(i) + ".png")
                                     for res in reses]

                    # Make sure all the required directories exist
                    # otherwise make them
                    os.makedirs(sample_dir, exist_ok=True)
                    for gen_img_file in gen_img_files:
                        os.makedirs(os.path.dirname(gen_img_file), exist_ok=True)

                    # clear any gradients accumulated by the feedback pass
                    dis_optim.zero_grad()
                    gen_optim.zero_grad()
                    with th.no_grad():
                        self.create_grid(
                            self.gen(fixed_input) if not self.use_ema
                            else self.gen_shadow(fixed_input),
                            gen_img_files)

                # increment the global_step:
                global_step += 1

                if i > limit:
                    break

            # calculate the time required for the epoch
            stop_time = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop_time - start_time))

            if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == num_epochs:
                os.makedirs(save_dir, exist_ok=True)
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(epoch) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(epoch) + ".pth")
                gen_optim_save_file = os.path.join(save_dir,
                                                   "GAN_GEN_OPTIM_" + str(epoch) + ".pth")
                dis_optim_save_file = os.path.join(save_dir,
                                                   "GAN_DIS_OPTIM_" + str(epoch) + ".pth")

                th.save(self.gen.state_dict(), gen_save_file)
                th.save(self.dis.state_dict(), dis_save_file)
                th.save(gen_optim.state_dict(), gen_optim_save_file)
                th.save(dis_optim.state_dict(), dis_optim_save_file)

                if self.use_ema:
                    gen_shadow_save_file = os.path.join(save_dir, "GAN_GEN_SHADOW_"
                                                        + str(epoch) + ".pth")
                    th.save(self.gen_shadow.state_dict(), gen_shadow_save_file)

        print("Training completed ...")

        # return the generator and discriminator back to eval mode
        self.gen.eval()
        self.dis.eval()
示例#17
0
class DNNModelPytorch(Model):
    """DNN Model
    Parameters
    ----------
    input_dim : int
        input dimension
    output_dim : int
        output dimension
    layers : tuple
        layer sizes
    lr : float
        learning rate
    lr_decay : float
        learning rate decay
    lr_decay_steps : int
        learning rate decay steps
    optimizer : str
        optimizer name
    GPU : int
        the GPU ID used for training
    """
    def __init__(
        self,
        lr=0.001,
        max_steps=300,
        batch_size=2000,
        early_stop_rounds=50,
        eval_steps=20,
        lr_decay=0.96,
        lr_decay_steps=100,
        optimizer="gd",
        loss="mse",
        GPU=0,
        seed=None,
        weight_decay=0.0,
        data_parall=False,
        scheduler: Optional[Union[
            Callable]] = "default",  # when it is Callable, it accept one argument named optimizer
        init_model=None,
        eval_train_metric=False,
        pt_model_uri="qlib.contrib.model.pytorch_nn.Net",
        # BUGFIX: was a mutable default dict shared across all instances;
        # None is used as a sentinel and the default built per call below.
        pt_model_kwargs=None,
        valid_key=DataHandlerLP.DK_L,
        # TODO: Infer Key is a more reasonable key. But it requires more detailed processing on label processing
    ):
        # build the default model kwargs fresh for each instance
        if pt_model_kwargs is None:
            pt_model_kwargs = {
                "input_dim": 360,
                "layers": (256, ),
            }
        # Set logger.
        self.logger = get_module_logger("DNNModelPytorch")
        self.logger.info("DNN pytorch version...")

        # set hyper-parameters.
        self.lr = lr
        self.max_steps = max_steps
        self.batch_size = batch_size
        self.early_stop_rounds = early_stop_rounds
        self.eval_steps = eval_steps
        self.lr_decay = lr_decay
        self.lr_decay_steps = lr_decay_steps
        self.optimizer = optimizer.lower()
        self.loss_type = loss
        # GPU may be a full device string (e.g. "cuda:1") or an int ID;
        # a negative int or unavailable CUDA falls back to CPU
        if isinstance(GPU, str):
            self.device = torch.device(GPU)
        else:
            self.device = torch.device(
                "cuda:%d" %
                (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
        self.seed = seed
        self.weight_decay = weight_decay
        self.data_parall = data_parall
        self.eval_train_metric = eval_train_metric
        self.valid_key = valid_key

        self.best_step = None

        self.logger.info("DNN parameters setting:"
                         f"\nlr : {lr}"
                         f"\nmax_steps : {max_steps}"
                         f"\nbatch_size : {batch_size}"
                         f"\nearly_stop_rounds : {early_stop_rounds}"
                         f"\neval_steps : {eval_steps}"
                         f"\nlr_decay : {lr_decay}"
                         f"\nlr_decay_steps : {lr_decay_steps}"
                         f"\noptimizer : {optimizer}"
                         f"\nloss_type : {loss}"
                         f"\nseed : {seed}"
                         f"\ndevice : {self.device}"
                         f"\nuse_GPU : {self.use_gpu}"
                         f"\nweight_decay : {weight_decay}"
                         f"\nenable data parall : {self.data_parall}"
                         f"\npt_model_uri: {pt_model_uri}"
                         f"\npt_model_kwargs: {pt_model_kwargs}")

        if self.seed is not None:
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)

        if loss not in {"mse", "binary"}:
            raise NotImplementedError("loss {} is not supported!".format(loss))
        self._scorer = mean_squared_error if loss == "mse" else roc_auc_score

        if init_model is None:
            self.dnn_model = init_instance_by_config({
                "class": pt_model_uri,
                "kwargs": pt_model_kwargs
            })

            if self.data_parall:
                self.dnn_model = DataParallel(self.dnn_model).to(self.device)
        else:
            self.dnn_model = init_model

        self.logger.info("model:\n{:}".format(self.dnn_model))
        self.logger.info("model size: {:.4f} MB".format(
            count_parameters(self.dnn_model)))

        if optimizer.lower() == "adam":
            self.train_optimizer = optim.Adam(self.dnn_model.parameters(),
                                              lr=self.lr,
                                              weight_decay=self.weight_decay)
        elif optimizer.lower() == "gd":
            self.train_optimizer = optim.SGD(self.dnn_model.parameters(),
                                             lr=self.lr,
                                             weight_decay=self.weight_decay)
        else:
            raise NotImplementedError(
                "optimizer {} is not supported!".format(optimizer))

        if scheduler == "default":
            # Reduce learning rate when loss has stopped decrease
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.train_optimizer,
                mode="min",
                factor=0.5,
                patience=10,
                verbose=True,
                threshold=0.0001,
                threshold_mode="rel",
                cooldown=0,
                min_lr=0.00001,
                eps=1e-08,
            )
        elif scheduler is None:
            self.scheduler = None
        else:
            # user-supplied factory: must accept a keyword argument `optimizer`
            self.scheduler = scheduler(optimizer=self.train_optimizer)

        self.fitted = False
        self.dnn_model.to(self.device)

    @property
    def use_gpu(self):
        # True whenever the selected device is anything other than the CPU
        return self.device != torch.device("cpu")

    def fit(
            self,
            dataset: DatasetH,
            # BUGFIX: was a mutable default `dict()` shared across calls;
            # None sentinel keeps the interface while avoiding the sharing.
            evals_result=None,
            verbose=True,
            save_path=None,
            reweighter=None,
    ):
        if evals_result is None:
            evals_result = dict()
        has_valid = "valid" in dataset.segments
        segments = ["train", "valid"]
        # renamed from `vars` to avoid shadowing the builtin
        var_names = ["x", "y", "w"]
        all_df = defaultdict(
            dict)  # x_train, x_valid y_train, y_valid w_train, w_valid
        all_t = defaultdict(dict)  # tensors
        for seg in segments:
            if seg in dataset.segments:
                # df_train df_valid
                df = dataset.prepare(seg,
                                     col_set=["feature", "label"],
                                     data_key=self.valid_key
                                     if seg == "valid" else DataHandlerLP.DK_L)
                all_df["x"][seg] = df["feature"]
                all_df["y"][seg] = df["label"].copy(
                )  # We have to use copy to remove the reference to release mem
                if reweighter is None:
                    all_df["w"][seg] = pd.DataFrame(np.ones_like(
                        all_df["y"][seg].values),
                                                    index=df.index)
                elif isinstance(reweighter, Reweighter):
                    all_df["w"][seg] = pd.DataFrame(reweighter.reweight(df))
                else:
                    raise ValueError("Unsupported reweighter type.")

                # get tensors
                for v in var_names:
                    all_t[v][seg] = torch.from_numpy(
                        all_df[v][seg].values).float()
                    # if seg == "valid": # accelerate the eval of validation
                    all_t[v][seg] = all_t[v][seg].to(
                        self.device)  # This will consume a lot of memory !!!!

                evals_result[seg] = []
                # free memory; "x" is re-created by the next segment iteration
                # (all_df is a defaultdict), while "y" is kept for metric index
                del df
                del all_df["x"]
                gc.collect()

        save_path = get_or_create_path(save_path)
        stop_steps = 0
        train_loss = 0
        best_loss = np.inf
        # train
        self.logger.info("training...")
        self.fitted = True
        # return
        # prepare training data
        train_num = all_t["y"]["train"].shape[0]

        for step in range(1, self.max_steps + 1):
            if stop_steps >= self.early_stop_rounds:
                if verbose:
                    self.logger.info("\tearly stop")
                break
            loss = AverageMeter()
            self.dnn_model.train()
            self.train_optimizer.zero_grad()
            choice = np.random.choice(train_num, self.batch_size)
            x_batch_auto = all_t["x"]["train"][choice].to(self.device)
            y_batch_auto = all_t["y"]["train"][choice].to(self.device)
            w_batch_auto = all_t["w"]["train"][choice].to(self.device)

            # forward
            preds = self.dnn_model(x_batch_auto)
            cur_loss = self.get_loss(preds, w_batch_auto, y_batch_auto,
                                     self.loss_type)
            cur_loss.backward()
            self.train_optimizer.step()
            loss.update(cur_loss.item())
            R.log_metrics(train_loss=loss.avg, step=step)

            # validation
            train_loss += loss.val
            # for every `eval_steps` steps or at the last steps, we will evaluate the model.
            if step % self.eval_steps == 0 or step == self.max_steps:
                if has_valid:
                    stop_steps += 1
                    train_loss /= self.eval_steps

                    with torch.no_grad():
                        self.dnn_model.eval()

                        # forward
                        preds = self._nn_predict(all_t["x"]["valid"],
                                                 return_cpu=False)
                        cur_loss_val = self.get_loss(preds,
                                                     all_t["w"]["valid"],
                                                     all_t["y"]["valid"],
                                                     self.loss_type)
                        loss_val = cur_loss_val.item()
                        metric_val = (self.get_metric(
                            preds.reshape(-1), all_t["y"]["valid"].reshape(-1),
                            all_df["y"]
                            ["valid"].index).detach().cpu().numpy().item())
                        R.log_metrics(val_loss=loss_val, step=step)
                        R.log_metrics(val_metric=metric_val, step=step)

                        if self.eval_train_metric:
                            metric_train = (self.get_metric(
                                self._nn_predict(all_t["x"]["train"],
                                                 return_cpu=False),
                                all_t["y"]["train"].reshape(-1),
                                all_df["y"]["train"].index,
                            ).detach().cpu().numpy().item())
                            R.log_metrics(train_metric=metric_train, step=step)
                        else:
                            metric_train = np.nan
                    if verbose:
                        self.logger.info(
                            f"[Step {step}]: train_loss {train_loss:.6f}, valid_loss {loss_val:.6f}, train_metric {metric_train:.6f}, valid_metric {metric_val:.6f}"
                        )
                    evals_result["train"].append(train_loss)
                    evals_result["valid"].append(loss_val)
                    if loss_val < best_loss:
                        if verbose:
                            self.logger.info(
                                "\tvalid loss update from {:.6f} to {:.6f}, save checkpoint."
                                .format(best_loss, loss_val))
                        best_loss = loss_val
                        self.best_step = step
                        R.log_metrics(best_step=self.best_step, step=step)
                        stop_steps = 0
                        torch.save(self.dnn_model.state_dict(), save_path)
                    train_loss = 0
                    # update learning rate
                    if self.scheduler is not None:
                        auto_filter_kwargs(self.scheduler.step,
                                           warning=False)(metrics=cur_loss_val,
                                                          epoch=step)
                    R.log_metrics(lr=self.get_lr(), step=step)
                else:
                    # retraining mode
                    if self.scheduler is not None:
                        self.scheduler.step(epoch=step)

        if has_valid:
            # restore the optimal parameters after training
            self.dnn_model.load_state_dict(
                torch.load(save_path, map_location=self.device))
        if self.use_gpu:
            torch.cuda.empty_cache()

    def get_lr(self):
        # only a single param group is expected; fail loudly otherwise
        assert len(self.train_optimizer.param_groups) == 1
        return self.train_optimizer.param_groups[0]["lr"]

    def get_loss(self, pred, w, target, loss_type):
        """Compute the sample-weighted loss between predictions and targets."""
        pred, w, target = pred.reshape(-1), w.reshape(-1), target.reshape(-1)
        if loss_type == "mse":
            sqr_loss = torch.mul(pred - target, pred - target)
            loss = torch.mul(sqr_loss, w).mean()
            return loss
        elif loss_type == "binary":
            loss = nn.BCEWithLogitsLoss(weight=w)
            return loss(pred, target)
        else:
            raise NotImplementedError(
                "loss {} is not supported!".format(loss_type))

    def get_metric(self, pred, target, index):
        # NOTE: the order of the index must follow <datetime, instrument> sorted order
        return -ICLoss()(pred, target, index)  # pylint: disable=E1130

    def _nn_predict(self, data, return_cpu=True):
        """Reusing predicting NN.
        Scenarios
        1) test inference (data may come from CPU and expect the output data is on CPU)
        2) evaluation on training (data may come from GPU)
        """
        if not isinstance(data, torch.Tensor):
            if isinstance(data, pd.DataFrame):
                data = data.values
            data = torch.Tensor(data)
        data = data.to(self.device)
        preds = []
        self.dnn_model.eval()
        with torch.no_grad():
            # fixed mini-batches keep peak memory bounded during inference
            batch_size = 8096
            for i in range(0, len(data), batch_size):
                x = data[i:i + batch_size]
                preds.append(
                    self.dnn_model(x.to(self.device)).detach().reshape(-1))
        if return_cpu:
            preds = np.concatenate([pr.cpu().numpy() for pr in preds])
        else:
            preds = torch.cat(preds, axis=0)
        return preds

    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
        if not self.fitted:
            raise ValueError("model is not fitted yet!")
        x_test_pd = dataset.prepare(segment,
                                    col_set="feature",
                                    data_key=DataHandlerLP.DK_I)
        preds = self._nn_predict(x_test_pd)
        return pd.Series(preds.reshape(-1), index=x_test_pd.index)

    def save(self, filename, **kwargs):
        with save_multiple_parts_file(filename) as model_dir:
            model_path = os.path.join(model_dir, os.path.split(model_dir)[-1])
            # Save model
            torch.save(self.dnn_model.state_dict(), model_path)

    def load(self, buffer, **kwargs):
        with unpack_archive_with_buffer(buffer) as model_dir:
            # Get model name
            _model_name = os.path.splitext(
                list(
                    filter(lambda x: x.startswith("model.bin"),
                           os.listdir(model_dir)))[0])[0]
            _model_path = os.path.join(model_dir, _model_name)
            # Load model
            self.dnn_model.load_state_dict(
                torch.load(_model_path, map_location=self.device))
        self.fitted = True
示例#18
0
class UNetTrainer(object):
    """Train/validate a UNet segmentation model and checkpoint it periodically.

    Parameters
    ----------
    start_epoch: int, default 0
        epoch to start counting from; 0 means "fresh run" (start at 1, or at
        ``checkpoint epoch + 1`` when resuming from a checkpoint).
    save_dir: str, default ''
        sub-directory of ``../models`` where checkpoints and logs are written.
    resume: str, default ''
        path of a checkpoint file to resume from; empty string disables it.
    devices_num: int, default 2
        number of GPUs; 2 wraps the net in DataParallel over devices 0 and 1.
    num_classes: int, default 2
        number of segmentation classes.
    color_dim: int, default 1
        number of input image channels.
    """
    def __init__(self,
                 start_epoch=0,
                 save_dir='',
                 resume="",
                 devices_num=2,
                 num_classes=2,
                 color_dim=1):
        self.net = UNet(color_dim=color_dim, num_classes=num_classes)
        self.start_epoch = start_epoch if start_epoch != 0 else 1
        self.save_dir = os.path.join('../models', save_dir)
        self.loss = CrossEntropyLoss()
        self.num_classes = num_classes
        if resume:
            checkpoint = torch.load(resume)
            # BUGFIX: the original tested ``self.start_epoch == 0`` *after*
            # remapping 0 -> 1, so the checkpoint epoch was never restored.
            # Test the raw argument instead.
            if start_epoch == 0:
                self.start_epoch = checkpoint['epoch'] + 1
            # BUGFIX: ``self.save_dir`` is never empty (it always contains
            # '../models'); the caller's argument is what signals "unset".
            if not save_dir:
                self.save_dir = checkpoint['save_dir']
            self.net.load_state_dict(checkpoint['state_dir'])
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        # self.net.cuda()
        # self.loss.cuda()
        if devices_num == 2:
            self.net = DataParallel(self.net, device_ids=[0, 1])

    def train(self,
              train_loader,
              val_loader,
              lr=0.001,
              weight_decay=1e-4,
              epochs=200,
              save_freq=10):
        """Run the full loop: one training + one validation pass per epoch.

        Side effect: redirects ``sys.stdout`` into a ``Logger`` so all prints
        are also written to ``save_dir/log``.
        """
        self.logfile = os.path.join(self.save_dir, 'log')
        sys.stdout = Logger(self.logfile)
        self.epochs = epochs
        self.lr = lr
        # NOTE(review): lr is stored but not passed to Adam (the original
        # relied on Adam's default); ``get_lr`` is currently unused.
        optimizer = torch.optim.Adam(
            self.net.parameters(),
            weight_decay=weight_decay)
        for epoch in range(self.start_epoch, epochs + 1):
            self.train_(train_loader, epoch, optimizer, save_freq)
            self.validate_(val_loader, epoch)

    def _save_sample_image(self, output, data, target, batch_size, epoch,
                           phase):
        """Save a [prediction | input | ground-truth] triptych for the first
        batch of an epoch under ../Try/temp/."""
        # Argmax over the class channel of the flattened logits.
        _, pred = output.data.max(dim=1)
        pred = pred.view(batch_size, 1, 1, 320, 480).cpu()
        img = data[0, 0].unsqueeze(0).unsqueeze(0)   # first input image
        gt = target[0].unsqueeze(0)                  # first ground truth
        triptych = torch.cat([pred[0].float(), img, gt.float()], 0)
        torchvision.utils.save_image(
            triptych, "../Try/temp/%02d_%s.jpg" % (epoch, phase), nrow=3)

    def _save_checkpoint(self, epoch):
        """Write a CPU-side checkpoint (epoch, save_dir, weights)."""
        if 'module' in dir(self.net):
            # Unwrap DataParallel so the checkpoint loads on a single device.
            state_dict = self.net.module.state_dict()
        else:
            state_dict = self.net.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()
        torch.save(
            {
                'epoch': epoch,
                'save_dir': self.save_dir,
                # (sic) legacy key name, read back as-is in __init__.
                'state_dir': state_dict
            }, os.path.join(self.save_dir, '%03d.ckpt' % epoch))

    def train_(self, data_loader, epoch, optimizer, save_freq=10):
        """One optimization pass over ``data_loader``; checkpoints every
        ``save_freq`` epochs and prints mean loss/accuracy."""
        start_time = time.time()
        self.net.train()
        metrics = []
        for i, (data, target) in enumerate(tqdm(data_loader)):
            data_t, target_t = data, target
            data = Variable(data)
            target = Variable(target)
            output = self.net(data)
            # Flatten (N, C, H, W) -> (N*H*W, C) for per-pixel cross entropy.
            output = output.transpose(1, 3).transpose(1, 2).contiguous().view(
                -1, self.num_classes)
            target = target.view(-1)
            loss_output = self.loss(output, target)
            optimizer.zero_grad()
            loss_output.backward()
            optimizer.step()
            loss_output = loss_output.item()
            acc = accuracy(output, target)
            metrics.append([loss_output, acc])
            if i == 0:
                self._save_sample_image(output, data_t, target_t,
                                        data.size(0), epoch, 'train')
        if epoch % save_freq == 0:
            self._save_checkpoint(epoch)
        end_time = time.time()
        metrics = np.asarray(metrics, np.float32)
        self.print_metrics(metrics, 'Train', end_time - start_time, epoch)

    def validate_(self, data_loader, epoch):
        """One evaluation pass over ``data_loader`` with gradients disabled."""
        start_time = time.time()
        self.net.eval()
        metrics = []
        # Evaluation only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            for i, (data, target) in enumerate(tqdm(data_loader)):
                data_t, target_t = data, target
                data = Variable(data, requires_grad=False)
                target = Variable(target, requires_grad=False)
                output = self.net(data)
                output = output.transpose(1, 3).transpose(
                    1, 2).contiguous().view(-1, self.num_classes)
                target = target.view(-1)
                loss_output = self.loss(output, target).item()
                acc = accuracy(output, target)
                metrics.append([loss_output, acc])
                if i == 0:
                    # BUGFIX: the original wrote "%02d_train.jpg" here,
                    # clobbering the training snapshot of the same epoch.
                    self._save_sample_image(output, data_t, target_t,
                                            data.size(0), epoch, 'val')
        end_time = time.time()
        metrics = np.asarray(metrics, np.float32)
        self.print_metrics(metrics, 'Validation', end_time - start_time)

    def print_metrics(self, metrics, phase, time, epoch=-1):
        """Print mean loss/accuracy for a phase.

        metrics: array of [loss, acc] rows, one per batch.
        """
        if epoch != -1:
            print("Epoch: {}".format(epoch), )
        print(phase, )
        print('loss %2.4f, accuracy %2.4f, time %2.2f' %
              (np.mean(metrics[:, 0]), np.mean(metrics[:, 1]), time))
        if phase != 'Train':
            print()

    def get_lr(self, epoch):
        """Step schedule: base lr until 50% of epochs, x0.1 until 80%,
        x0.01 afterwards."""
        if epoch <= self.epochs * 0.5:
            lr = self.lr
        elif epoch <= self.epochs * 0.8:
            lr = 0.1 * self.lr
        else:
            lr = 0.01 * self.lr
        return lr

    def save_py_files(self, path):
        """copy .py files in exps dir, cfgs dir and current dir into
            save_dir, and keep the files structure
        """
        # exps dir
        pyfiles = [f for f in os.listdir(path) if f.endswith('.py')]
        path = "/".join(path.split('/')[-2:])
        exp_save_path = os.path.join(self.save_dir, path)
        mkdir(exp_save_path)
        for f in pyfiles:
            shutil.copy(os.path.join(path, f), os.path.join(exp_save_path, f))
        # current dir
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(self.save_dir, f))
        # cfgs dir
        shutil.copytree('./cfgs', os.path.join(self.save_dir, 'cfgs'))
示例#19
0
文件: core.py 项目: Duplums/pynet
class Base(Observable):
    """ Class to perform classification.
    """
    def __init__(self, optimizer_name="Adam", learning_rate=1e-3,
                 loss_name="NLLLoss", metrics=None, use_cuda=False,
                 pretrained=None, freeze_until_layer=None, load_optimizer=True, use_multi_gpu=True,
                 **kwargs):
        """ Class instantiation.

        Observers will be notified, allowed signals are:
        - 'before_epoch'
        - 'after_epoch'

        Parameters
        ----------
        optimizer_name: str, default 'Adam'
            the name of the optimizer: see 'torch.optim' for a description
            of available optimizer.
        learning_rate: float, default 1e-3
            the optimizer learning rate.
        loss_name: str, default 'NLLLoss'
            the name of the loss: see 'torch.nn' for a description
            of available loss.
        metrics: list of str
            a list of extra metrics that will be computed.
        use_cuda: bool, default False
            whether to use GPU or CPU.
        pretrained: path, default None
            path to the pretrained model or weights.
        freeze_until_layer: str, default None
            if set, freeze the model weights up to (and including) this layer.
        load_optimizer: boolean, default True
            if pretrained is set, whether to also load the optimizer's weights or not
        use_multi_gpu: boolean, default True
            if several GPUs are available, use them during forward/backward pass
        kwargs: dict
            specify directly a custom 'model', 'optimizer' or 'loss'. Can also
            be used to set specific optimizer parameters.
        """
        super().__init__(
            signals=["before_epoch", "after_epoch", "after_iteration"])
        self.optimizer = kwargs.get("optimizer")
        self.logger = logging.getLogger("pynet")
        self.loss = kwargs.get("loss")
        self.device = torch.device("cuda" if use_cuda else "cpu")
        # 'optimizer'/'loss' must not be forwarded to the optimizer factory.
        for name in ("optimizer", "loss"):
            if name in kwargs:
                kwargs.pop(name)
        if "model" in kwargs:
            self.model = kwargs.pop("model")
        if self.optimizer is None:
            if optimizer_name in dir(torch.optim):
                self.optimizer = getattr(torch.optim, optimizer_name)(
                    self.model.parameters(),
                    lr=learning_rate,
                    **kwargs)
            else:
                # BUGFIX: the '{0}' placeholder was never filled in (and
                # 'uknown' was misspelled).
                raise ValueError("Optimizer '{0}' unknown: check available "
                                 "optimizer in 'pytorch.optim'.".format(
                                     optimizer_name))
        if self.loss is None:
            if loss_name not in dir(torch.nn):
                # BUGFIX: same unformatted placeholder as above.
                raise ValueError("Loss '{0}' unknown: check available loss in "
                                 "'pytorch.nn'.".format(loss_name))
            self.loss = getattr(torch.nn, loss_name)()
        self.metrics = {}
        for name in (metrics or []):
            if name not in mmetrics.METRICS:
                raise ValueError("Metric '{0}' not yet supported: you can try "
                                 "to fill the 'METRICS' factory, or ask for "
                                 "some help!".format(name))
            self.metrics[name] = mmetrics.METRICS[name]
        if use_cuda and not torch.cuda.is_available():
            raise ValueError("No GPU found: unset 'use_cuda' parameter.")
        if pretrained is not None:
            checkpoint = None
            try:
                # Load on CPU first; tensors are moved to self.device below.
                checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)
            except BaseException as e:
                self.logger.error('Impossible to load the checkpoint: %s' % str(e))
            if checkpoint is not None:
                if hasattr(checkpoint, "state_dict"):
                    # A full serialized model object.
                    self.model.load_state_dict(checkpoint.state_dict())
                elif isinstance(checkpoint, dict):
                    if "model" in checkpoint:
                        try:
                            # Strip a possible 'module.' prefix left over from
                            # a DataParallel-wrapped checkpoint.
                            for key in list(checkpoint['model'].keys()):
                                if key.replace('module.', '') != key:
                                    checkpoint['model'][key.replace('module.', '')] = checkpoint['model'][key]
                                    del checkpoint['model'][key]
                            unexpected = self.model.load_state_dict(checkpoint["model"], strict=False)
                            self.logger.info('Model loading info: {}'.format(unexpected))
                            self.logger.info('Model loaded')
                        except BaseException as e:
                            self.logger.error('Error while loading the model\'s weights: %s' % str(e))
                            # BUGFIX: re-raise with an informative message and
                            # the original cause instead of ValueError("").
                            raise ValueError(
                                "Unable to load the model's weights") from e
                    if "optimizer" in checkpoint:
                        if load_optimizer:
                            try:
                                self.optimizer.load_state_dict(checkpoint["optimizer"])
                                # Optimizer state tensors were restored on CPU;
                                # move them to the target device.
                                for state in self.optimizer.state.values():
                                    for k, v in state.items():
                                        if torch.is_tensor(v):
                                            state[k] = v.to(self.device)
                            except BaseException as e:
                                self.logger.error('Error while loading the optimizer\'s weights: %s' % str(e))
                        else:
                            self.logger.warning("The optimizer's weights are not restored ! ")
                else:
                    # A raw state dict.
                    self.model.load_state_dict(checkpoint)
        if freeze_until_layer is not None:
            freeze_until(self.model, freeze_until_layer)

        if use_multi_gpu and torch.cuda.device_count() > 1:
            self.model = DataParallel(self.model)

        self.model = self.model.to(self.device)

    def training(self, manager: AbstractDataManager, nb_epochs: int, checkpointdir=None,
                 fold_index=None, epoch_index=None,
                 scheduler=None, with_validation=True, with_visualization=False,
                 nb_epochs_per_saving=1, exp_name=None, standard_optim=True,
                 gpu_time_profiling=False, **kwargs_train):
        """ Train the model.

        Parameters
        ----------
        manager: a pynet DataManager
            a manager containing the train and validation data.
        nb_epochs: int, default 100
            the number of epochs.
        checkpointdir: str, default None
            a destination folder where intermediate models/histories will be
            saved.
        fold_index: int or [int] default None
            the index(es) of the fold(s) to use for the training, default use all the
            available folds.
        epoch_index: int, default None
            the iteration where to start the counting from
        scheduler: torch.optim.lr_scheduler, default None
            a scheduler used to reduce the learning rate.
        with_validation: bool, default True
            if set use the validation dataset.
        with_visualization: bool, default False,
            whether it uses a visualizer that will plot the losses/metrics/images in a WebApp framework
            during the training process
        nb_epochs_per_saving: int, default 1,
            the number of epochs after which the model+optimizer's parameters are saved
        exp_name: str, default None
            the experience name that will be launched
        Returns
        -------
        train_history, valid_history: History
            the train/validation history.
        """

        train_history = History(name="Train_%s" % (exp_name or ""))
        # BUGFIX: 'with_validation' is a boolean; the original tested
        # 'is not None', which created a validation history even for False.
        if with_validation:
            valid_history = History(name="Validation_%s" % (exp_name or ""))
        else:
            valid_history = None
        train_visualizer, valid_visualizer = None, None
        if with_visualization:
            train_visualizer = Visualizer(train_history)
            if with_validation:
                valid_visualizer = Visualizer(valid_history, offset_win=10)
        print(self.loss)
        print(self.optimizer)
        folds = range(manager.get_nb_folds())
        if fold_index is not None:
            if isinstance(fold_index, int):
                folds = [fold_index]
            elif isinstance(fold_index, list):
                folds = fold_index
        if epoch_index is None:
            epoch_index = 0
        # Snapshot the initial states so every fold starts from scratch.
        init_optim_state = deepcopy(self.optimizer.state_dict())
        init_model_state = deepcopy(self.model.state_dict())
        if scheduler is not None:
            init_scheduler_state = deepcopy(scheduler.state_dict())
        for fold in folds:
            # Initialize everything before optimizing on a new fold
            self.optimizer.load_state_dict(init_optim_state)
            self.model.load_state_dict(init_model_state)
            if scheduler is not None:
                scheduler.load_state_dict(init_scheduler_state)
            loader = manager.get_dataloader(
                train=True,
                validation=True,
                fold_index=fold)
            for epoch in range(nb_epochs):
                self.notify_observers("before_epoch", epoch=epoch, fold=fold)
                loss, values = self.train(loader.train, train_visualizer, fold, epoch,
                                          standard_optim=standard_optim,
                                          gpu_time_profiling=gpu_time_profiling, **kwargs_train)

                train_history.log((fold, epoch+epoch_index), loss=loss, **values)
                train_history.summary()
                if scheduler is not None:
                    scheduler.step()
                    print('Scheduler lr: {}'.format(scheduler.get_lr()), flush=True)
                    print('Optimizer lr: %f' % self.optimizer.param_groups[0]['lr'], flush=True)
                if checkpointdir is not None and (epoch % nb_epochs_per_saving == 0 or epoch == nb_epochs-1) \
                        and epoch > 0:
                    checkpoint(
                        model=self.model,
                        epoch=epoch+epoch_index,
                        fold=fold,
                        outdir=checkpointdir,
                        name=exp_name,
                        optimizer=self.optimizer)
                    train_history.save(
                        outdir=checkpointdir,
                        epoch=epoch+epoch_index,
                        fold=fold)
                if with_validation:
                    _, _, _, loss, values = self.test(loader.validation,
                                                      standard_optim=standard_optim, **kwargs_train)
                    valid_history.log((fold, epoch+epoch_index), validation_loss=loss, **values)
                    valid_history.summary()
                    if valid_visualizer is not None:
                        valid_visualizer.refresh_current_metrics()
                    if checkpointdir is not None and (epoch % nb_epochs_per_saving == 0 or epoch == nb_epochs-1) \
                            and epoch > 0:
                        valid_history.save(
                            outdir=checkpointdir,
                            epoch=epoch+epoch_index,
                            fold=fold)
                self.notify_observers("after_epoch", epoch=epoch, fold=fold)
        return train_history, valid_history

    def train(self, loader, visualizer=None, fold=None, epoch=None, standard_optim=True,
              gpu_time_profiling=False, **kwargs):
        """ Train the model on the trained data.

        Parameters
        ----------
        loader: a pytorch Dataloader

        Returns
        -------
        loss: float
            the value of the loss function.
        values: dict
            the values of the metrics.
        """

        self.model.train()
        nb_batch = len(loader)
        pbar = tqdm(total=nb_batch, desc="Mini-Batch")

        values = {}
        iteration = 0
        if gpu_time_profiling:
            gpu_time_per_batch = []
        if not standard_optim:
            # The model drives its own optimization loop.
            loss, values = self.model(iter(loader), pbar=pbar, visualizer=visualizer)
        else:
            losses = []
            y_pred = []
            y_true = []
            for dataitem in loader:
                pbar.update()
                inputs = dataitem.inputs
                if isinstance(inputs, torch.Tensor):
                    inputs = inputs.to(self.device)
                # The loss may take one target (outputs OR labels) or both.
                list_targets = []
                _targets = []
                for item in (dataitem.outputs, dataitem.labels):
                    if item is not None:
                        _targets.append(item.to(self.device))
                if len(_targets) == 1:
                    _targets = _targets[0]
                list_targets.append(_targets)

                self.optimizer.zero_grad()
                if gpu_time_profiling:
                    start_event = torch.cuda.Event(enable_timing=True)
                    end_event = torch.cuda.Event(enable_timing=True)
                    start_event.record()

                outputs = self.model(inputs)

                if gpu_time_profiling:
                    end_event.record()
                    torch.cuda.synchronize()
                    elapsed_time_ms = start_event.elapsed_time(end_event)
                    gpu_time_per_batch.append(elapsed_time_ms)

                batch_loss = self.loss(outputs, *list_targets)
                batch_loss.backward()
                self.optimizer.step()

                losses.append(float(batch_loss))
                y_pred.extend(outputs.detach().cpu().numpy())
                y_true.extend(list_targets[0].detach().cpu().numpy())

                # Accumulate auxiliary losses exposed by the model or the loss.
                aux_losses = (self.model.get_aux_losses() if hasattr(self.model, 'get_aux_losses') else dict())
                aux_losses.update(self.loss.get_aux_losses() if hasattr(self.loss, 'get_aux_losses') else dict())

                for name, aux_loss in aux_losses.items():
                    if name not in values:
                        values[name] = 0
                    values[name] += float(aux_loss) / nb_batch
                if iteration % 10 == 0:
                    if visualizer is not None:
                        visualizer.refresh_current_metrics()
                        if hasattr(self.model, "get_current_visuals"):
                            visuals = self.model.get_current_visuals()
                            visualizer.display_images(visuals, ncols=3)
                iteration += 1
            loss = np.mean(losses)
            # Epoch-level metrics computed over all predictions.
            for name, metric in self.metrics.items():
                values[name] = float(metric(torch.tensor(y_pred), torch.tensor(y_true)))

        if gpu_time_profiling:
            self.logger.info("GPU Time Statistics over 1 epoch:\n\t- {:.2f} +/- {:.2f} ms calling model(data) per batch"
                                                              "\n\t- {:.2f} ms total time over 1 epoch ({} batches)".format(
                np.mean(gpu_time_per_batch), np.std(gpu_time_per_batch), np.sum(gpu_time_per_batch), nb_batch))
        pbar.close()
        return loss, values

    def testing(self, loader: DataLoader, with_logit=False, predict=False, with_visuals=False,
                saving_dir=None, exp_name=None, standard_optim=True, **kwargs):
        """ Evaluate the model.

        Parameters
        ----------
        loader: a pytorch DataLoader
        with_logit: bool, default False
            apply a softmax to the result.
        predict: bool, default False
            take the argmax over the channels.
        with_visuals: bool, default False
            returns the visuals got from the model
        saving_dir: str, default None
            if set, pickle the predictions/targets/loss/metrics there.
        Returns
        -------
        y: array-like
            the predicted data.
        X: array-like
            the input data.
        y_true: array-like
            the true data if available.
        loss: float
            the value of the loss function if true data availble.
        values: dict
            the values of the metrics if true data availble.
        """
        if with_visuals:
            y, y_true, X, loss, values, visuals = self.test(
                loader.test, with_logit=with_logit, predict=predict, with_visuals=with_visuals,
                standard_optim=standard_optim)
        else:
            y, y_true, X, loss, values = self.test(
                loader.test, with_logit=with_logit, predict=predict, with_visuals=with_visuals,
                standard_optim=standard_optim)

        if saving_dir is not None:
            with open(os.path.join(saving_dir, (exp_name or 'test')+'.pkl'), 'wb') as f:
                pickle.dump({'y_pred': y, 'y_true': y_true, 'loss': loss, 'metrics': values}, f)

        if with_visuals:
            return y, X, y_true, loss, values, visuals

        return y, X, y_true, loss, values

    def test(self, loader, with_logit=False, predict=False, with_visuals=False, standard_optim=True):
        """ Evaluate the model on the test or validation data.

        Parameter
        ---------
        loader: a pytorch Dataset
            the data loader.
        with_logit: bool, default False
            apply a softmax to the result.
        predict: bool, default False
            take the argmax over the channels.

        Returns
        -------
        y: array-like
            the predicted data.
        y_true: array-like
            the true data
        X: array_like
            the input data
        loss: float
            the value of the loss function.
        values: dict
            the values of the metrics.
        """

        self.model.eval()
        nb_batch = len(loader)
        pbar = tqdm(total=nb_batch, desc="Mini-Batch")
        loss = 0
        values = {}
        visuals = []

        with torch.no_grad():
            y, y_true, X = [], [], []
            if not standard_optim:
                loss, values, y, y_true, X = self.model(iter(loader), pbar=pbar)
            else:
                for dataitem in loader:
                    pbar.update()
                    inputs = dataitem.inputs
                    if isinstance(inputs, torch.Tensor):
                        inputs = inputs.to(self.device)
                    list_targets = []
                    targets = []
                    for item in (dataitem.outputs, dataitem.labels):
                        if item is not None:
                            targets.append(item.to(self.device))
                            y_true.extend(item.cpu().detach().numpy())
                    if len(targets) == 1:
                        targets = targets[0]
                    elif len(targets) == 0:
                        targets = None
                    if targets is not None:
                        list_targets.append(targets)

                    outputs = self.model(inputs)
                    if with_visuals:
                        visuals.append(self.model.get_current_visuals())
                    # Loss only computable when targets are available.
                    if len(list_targets) > 0:
                        batch_loss = self.loss(outputs, *list_targets)
                        loss += float(batch_loss) / nb_batch

                    y.extend(outputs.cpu().detach().numpy())

                    if isinstance(inputs, torch.Tensor):
                        X.extend(inputs.cpu().detach().numpy())

                    aux_losses = (self.model.get_aux_losses() if hasattr(self.model, 'get_aux_losses') else dict())
                    aux_losses.update(self.loss.get_aux_losses() if hasattr(self.loss, 'get_aux_losses') else dict())
                    for name, aux_loss in aux_losses.items():
                        name += " on validation set"
                        if name not in values:
                            values[name] = 0
                        values[name] += aux_loss / nb_batch

                # Now computes the metrics with (y, y_true)
                for name, metric in self.metrics.items():
                    name += " on validation set"
                    values[name] = metric(torch.tensor(y), torch.tensor(y_true))
            pbar.close()

            if len(visuals) > 0:
                visuals = np.concatenate(visuals, axis=0)
            try:
                if with_logit:
                    y = func.softmax(torch.tensor(y), dim=1).detach().cpu().numpy()
                if predict:
                    y = np.argmax(y, axis=1)
            except Exception as e:
                print(e)
        if with_visuals:
            return y, y_true, X, loss, values, visuals
        return y, y_true, X, loss, values

    def MC_test(self, loader,  MC=50):
        """ Evaluate the model on the test or validation data by using a Monte-Carlo sampling.

        Parameter
        ---------
        loader: a pytorch Dataset
            the data loader.
        MC: int, default 50
            nb of times to perform a feed-forward per input

        Returns
        -------
        y: array-like dims (n_samples, MC, ...) where ... is the dims of the network's output
            the predicted data.
        y_true: array-like dims (n_samples, MC, ...) where ... is the dims of the network's output
            the true data
        """
        self.model.eval()
        nb_batch = len(loader)
        pbar = tqdm(total=nb_batch, desc="Mini-Batch")

        with torch.no_grad():
            y, y_true = [], []
            for dataitem in loader:
                pbar.update()
                inputs = dataitem.inputs
                if isinstance(inputs, torch.Tensor):
                    inputs = inputs.to(self.device)
                current_y, current_y_true = [], []
                # MC stochastic forward passes over the same batch.
                for _ in range(MC):
                    for item in (dataitem.outputs, dataitem.labels):
                        if item is not None:
                            current_y_true.append(item.cpu().detach().numpy())

                    outputs = self.model(inputs)
                    current_y.append(outputs.cpu().detach().numpy())
                # Swap (MC, batch, ...) -> (batch, MC, ...) before collecting.
                y.extend(np.array(current_y).swapaxes(0, 1))
                y_true.extend(np.array(current_y_true).swapaxes(0, 1))
        pbar.close()

        return np.array(y), np.array(y_true)
示例#20
0
def train_model(train_dataset, train_num_each, val_dataset, val_num_each):
    """Fine-tune the multi-task CNN-LSTM (p2t variant) on surgical sequences.

    Builds sequence-aligned sampler indices for both splits, transfers the
    shared/lstm/fc/fc2 weights from a pretrained ``multi_lstm`` checkpoint
    into a ``multi_lstm_p2t`` model (whose ``fc_p2t`` head is loaded from a
    second checkpoint and frozen), then trains with a combined
    BCEWithLogits + CrossEntropy + alpha*|KLDiv| loss, tracking the best
    validation checkpoint and a per-epoch metrics record.

    Parameters
    ----------
    train_dataset, val_dataset : torch Datasets of (frames, tool-labels, phase-label)
    train_num_each, val_num_each : per-video frame counts, used to compute
        valid sequence start indices.

    Relies on module-level configuration globals (sequence_length, num_gpu,
    epochs, learning_rate, ...).  Saves the best model weights and a .npy
    metrics record to the working directory.
    """
    num_train = len(train_dataset)
    num_val = len(val_dataset)

    train_useful_start_idx = get_useful_start_idx(sequence_length,
                                                  train_num_each)

    val_useful_start_idx = get_useful_start_idx(sequence_length, val_num_each)

    # truncate so every GPU receives the same number of sequences
    num_train_we_use = len(train_useful_start_idx) // num_gpu * num_gpu
    num_val_we_use = len(val_useful_start_idx) // num_gpu * num_gpu
    # num_train_we_use = 800
    # num_val_we_use = 80

    train_we_use_start_idx = train_useful_start_idx[0:num_train_we_use]
    val_we_use_start_idx = val_useful_start_idx[0:num_val_we_use]

    train_idx = []
    for i in range(num_train_we_use):
        for j in range(sequence_length):
            train_idx.append(train_we_use_start_idx[i] + j)

    val_idx = []
    for i in range(num_val_we_use):
        for j in range(sequence_length):
            val_idx.append(val_we_use_start_idx[i] + j)

    num_train_all = len(train_idx)
    num_val_all = len(val_idx)

    print('num train start idx : {:6d}'.format(len(train_useful_start_idx)))
    print('last idx train start: {:6d}'.format(train_useful_start_idx[-1]))
    print('num of train dataset: {:6d}'.format(num_train))
    print('num of train we use : {:6d}'.format(num_train_we_use))
    print('num of all train use: {:6d}'.format(num_train_all))
    print('num valid start idx : {:6d}'.format(len(val_useful_start_idx)))
    print('last idx valid start: {:6d}'.format(val_useful_start_idx[-1]))
    print('num of valid dataset: {:6d}'.format(num_val))
    print('num of valid we use : {:6d}'.format(num_val_we_use))
    print('num of all valid use: {:6d}'.format(num_val_all))

    train_loader = DataLoader(train_dataset,
                              batch_size=train_batch_size,
                              sampler=train_idx,
                              num_workers=workers,
                              pin_memory=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=val_batch_size,
                            sampler=val_idx,
                            num_workers=workers,
                            pin_memory=False)

    # load the pretrained multi-task CNN-LSTM and transfer its trunk
    model_old = multi_lstm()
    model_old = DataParallel(model_old)
    model_old.load_state_dict(
        torch.load(
            "cnn_lstm_epoch_25_length_10_opt_1_mulopt_1_flip_0_crop_1_batch_400_train1_9997_train2_9982_val1_9744_val2_8876.pth"
        ))

    model = multi_lstm_p2t()
    model.share = model_old.module.share
    model.lstm = model_old.module.lstm
    model.fc = model_old.module.fc
    model.fc2 = model_old.module.fc2

    model = DataParallel(model)
    # the p2t head is loaded from its own checkpoint and kept frozen
    for param in model.module.fc_p2t.parameters():
        param.requires_grad = False
    model.module.fc_p2t.load_state_dict(
        torch.load(
            "fc_epoch_25_length_4_opt_1_mulopt_1_flip_0_crop_1_batch_800_train1_9951_train2_9713_val1_9686_val2_7867_p2t.pth"
        ))

    if use_gpu:
        model = model.cuda()
        model.module.fc_p2t = model.module.fc_p2t.cuda()

    criterion_1 = nn.BCEWithLogitsLoss(size_average=False)
    criterion_2 = nn.CrossEntropyLoss(size_average=False)
    criterion_3 = nn.KLDivLoss(size_average=False)
    sigmoid_cuda = nn.Sigmoid().cuda()

    if multi_optim == 0:
        if optimizer_choice == 0:
            optimizer = optim.SGD([{
                'params': model.module.share.parameters()
            }, {
                'params': model.module.lstm.parameters(),
            }, {
                'params': model.module.fc.parameters()
            }, {
                'params': model.module.fc2.parameters()
            }],
                                  lr=learning_rate,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_step,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam([{
                'params': model.module.share.parameters()
            }, {
                'params': model.module.lstm.parameters(),
            }, {
                'params': model.module.fc.parameters()
            }, {
                'params': model.module.fc2.parameters()
            }],
                                   lr=learning_rate)
    elif multi_optim == 1:
        # smaller learning rate for the pretrained CNN trunk
        if optimizer_choice == 0:
            optimizer = optim.SGD([{
                'params': model.module.share.parameters()
            }, {
                'params': model.module.lstm.parameters(),
                'lr': learning_rate
            }, {
                'params': model.module.fc.parameters(),
                'lr': learning_rate
            }],
                                  lr=learning_rate / 10,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_step,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam([{
                'params': model.module.share.parameters()
            }, {
                'params': model.module.lstm.parameters(),
                'lr': learning_rate
            }, {
                'params': model.module.fc.parameters(),
                'lr': learning_rate
            }, {
                'params': model.module.fc2.parameters(),
                'lr': learning_rate
            }],
                                   lr=learning_rate / 10)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_accuracy_1 = 0.0
    best_val_accuracy_2 = 0.0
    correspond_train_acc_1 = 0.0
    correspond_train_acc_2 = 0.0

    # per-epoch record: 3 train accuracies, 3 train losses,
    # 3 valid accuracies, 3 valid losses -> 12 values
    record_np = np.zeros([epochs, 12])

    # default name for the .npy record; overwritten with a metrics-tagged
    # name whenever a checkpoint is saved.  Without this, np.save below
    # raised NameError when val_accuracy_2 never exceeded the threshold.
    public_name = "cnn_lstm_p2t" \
                  + "_epoch_" + str(epochs) \
                  + "_length_" + str(sequence_length) \
                  + "_opt_" + str(optimizer_choice) \
                  + "_mulopt_" + str(multi_optim) \
                  + "_flip_" + str(use_flip) \
                  + "_crop_" + str(crop_type) \
                  + "_batch_" + str(train_batch_size)

    for epoch in range(epochs):
        # np.random.seed(epoch)
        # reshuffle the sequence start points each epoch
        np.random.shuffle(train_we_use_start_idx)
        train_idx = []
        for i in range(num_train_we_use):
            for j in range(sequence_length):
                train_idx.append(train_we_use_start_idx[i] + j)

        train_loader = DataLoader(train_dataset,
                                  batch_size=train_batch_size,
                                  sampler=train_idx,
                                  num_workers=workers,
                                  pin_memory=False)

        model.train()
        train_loss_1 = 0.0
        train_loss_2 = 0.0
        train_loss_3 = 0.0
        train_corrects_1 = 0
        train_corrects_2 = 0
        train_corrects_3 = 0

        train_start_time = time.time()
        for data in train_loader:
            inputs, labels_1, labels_2 = data
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels_1 = Variable(labels_1.cuda())
                labels_2 = Variable(labels_2.cuda())
            else:
                inputs = Variable(inputs)
                labels_1 = Variable(labels_1)
                labels_2 = Variable(labels_2)

            optimizer.zero_grad()

            outputs_1, outputs_2, outputs_3 = model.forward(inputs)

            _, preds_2 = torch.max(outputs_2.data, 1)
            train_corrects_2 += torch.sum(preds_2 == labels_2.data)

            sig_output_1 = sigmoid_cuda(outputs_1)
            sig_output_3 = sigmoid_cuda(outputs_3)

            # ensemble of the two tool heads for the third accuracy
            sig_average = (sig_output_1.data + sig_output_3.data) / 2

            preds_1 = torch.cuda.ByteTensor(sig_output_1.data > 0.5)
            preds_1 = preds_1.long()
            train_corrects_1 += torch.sum(preds_1 == labels_1.data)

            preds_3 = torch.cuda.ByteTensor(sig_average > 0.5)
            preds_3 = preds_3.long()
            train_corrects_3 += torch.sum(preds_3 == labels_1.data)

            labels_1 = Variable(labels_1.data.float())
            loss_1 = criterion_1(outputs_1, labels_1)
            loss_2 = criterion_2(outputs_2, labels_2)

            # detach the frozen head's output so KL only trains head 1
            sig_output_3 = Variable(sig_output_3.data, requires_grad=False)
            loss_3 = torch.abs(criterion_3(sig_output_1, sig_output_3))
            loss = loss_1 + loss_2 + loss_3 * alpha
            loss.backward()
            optimizer.step()

            train_loss_1 += loss_1.data[0]
            train_loss_2 += loss_2.data[0]
            train_loss_3 += loss_3.data[0]

        train_elapsed_time = time.time() - train_start_time
        # 7 = number of multi-label tool classes per frame
        train_accuracy_1 = train_corrects_1 / num_train_all / 7
        train_accuracy_2 = train_corrects_2 / num_train_all
        train_accuracy_3 = train_corrects_3 / num_train_all / 7
        train_average_loss_1 = train_loss_1 / num_train_all / 7
        train_average_loss_2 = train_loss_2 / num_train_all
        train_average_loss_3 = train_loss_3 / num_train_all

        # begin eval

        model.eval()
        val_loss_1 = 0.0
        val_loss_2 = 0.0
        val_loss_3 = 0.0
        val_corrects_1 = 0
        val_corrects_2 = 0
        val_corrects_3 = 0

        val_start_time = time.time()
        for data in val_loader:
            inputs, labels_1, labels_2 = data
            # phase labels: keep only the last frame of each sequence
            labels_2 = labels_2[(sequence_length - 1)::sequence_length]
            if use_gpu:
                inputs = Variable(inputs.cuda(), volatile=True)
                labels_1 = Variable(labels_1.cuda(), volatile=True)
                labels_2 = Variable(labels_2.cuda(), volatile=True)
            else:
                inputs = Variable(inputs, volatile=True)
                labels_1 = Variable(labels_1, volatile=True)
                labels_2 = Variable(labels_2, volatile=True)

            outputs_1, outputs_2, outputs_3 = model.forward(inputs)
            outputs_2 = outputs_2[(sequence_length - 1)::sequence_length]
            _, preds_2 = torch.max(outputs_2.data, 1)
            val_corrects_2 += torch.sum(preds_2 == labels_2.data)

            sig_output_1 = sigmoid_cuda(outputs_1)
            sig_output_3 = sigmoid_cuda(outputs_3)

            sig_average = (sig_output_1.data + sig_output_3.data) / 2

            preds_1 = torch.cuda.ByteTensor(sig_output_1.data > 0.5)
            preds_1 = preds_1.long()
            val_corrects_1 += torch.sum(preds_1 == labels_1.data)

            preds_3 = torch.cuda.ByteTensor(sig_average > 0.5)
            preds_3 = preds_3.long()
            val_corrects_3 += torch.sum(preds_3 == labels_1.data)

            labels_1 = Variable(labels_1.data.float())
            loss_1 = criterion_1(outputs_1, labels_1)
            loss_2 = criterion_2(outputs_2, labels_2)

            sig_output_3 = Variable(sig_output_3.data, requires_grad=False)
            loss_3 = torch.abs(criterion_3(sig_output_1, sig_output_3))

            val_loss_1 += loss_1.data[0]
            val_loss_2 += loss_2.data[0]
            val_loss_3 += loss_3.data[0]

        val_elapsed_time = time.time() - val_start_time
        val_accuracy_1 = val_corrects_1 / (num_val_all * 7)
        val_accuracy_2 = val_corrects_2 / num_val_we_use
        val_accuracy_3 = val_corrects_3 / (num_val_all * 7)
        val_average_loss_1 = val_loss_1 / (num_val_all * 7)
        val_average_loss_2 = val_loss_2 / num_val_we_use
        val_average_loss_3 = val_loss_3 / num_val_all

        print('epoch: {:3d}'
              ' train time: {:2.0f}m{:2.0f}s'
              ' train accu_1: {:.4f}'
              ' train accu_3: {:.4f}'
              ' train accu_2: {:.4f}'
              ' train loss_1: {:4.4f}'
              ' train loss_2: {:4.4f}'
              ' train loss_3: {:4.4f}'.format(
                  epoch, train_elapsed_time // 60, train_elapsed_time % 60,
                  train_accuracy_1, train_accuracy_3, train_accuracy_2,
                  train_average_loss_1, train_average_loss_2,
                  train_average_loss_3))
        print('epoch: {:3d}'
              ' valid time: {:2.0f}m{:2.0f}s'
              ' valid accu_1: {:.4f}'
              ' valid accu_3: {:.4f}'
              ' valid accu_2: {:.4f}'
              ' valid loss_1: {:4.4f}'
              ' valid loss_2: {:4.4f}'
              ' valid loss_3: {:4.4f}'.format(
                  epoch, val_elapsed_time // 60, val_elapsed_time % 60,
                  val_accuracy_1, val_accuracy_3, val_accuracy_2,
                  val_average_loss_1, val_average_loss_2, val_average_loss_3))

        if optimizer_choice == 0:
            if sgd_adjust_lr == 0:
                exp_lr_scheduler.step()
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler.step(val_average_loss_1 + val_average_loss_2 +
                                      alpha * val_average_loss_3)

        # model selection: primary key is phase accuracy (accu_2), with tool
        # accuracy (accu_1) and then train accuracies as tie-breakers
        if val_accuracy_2 > best_val_accuracy_2 and val_accuracy_1 > 0.95:
            best_val_accuracy_2 = val_accuracy_2
            best_val_accuracy_1 = val_accuracy_1
            correspond_train_acc_1 = train_accuracy_1
            correspond_train_acc_2 = train_accuracy_2
            best_model_wts = copy.deepcopy(model.state_dict())
        elif val_accuracy_2 == best_val_accuracy_2 and val_accuracy_1 > 0.95:
            if val_accuracy_1 > best_val_accuracy_1:
                # keep the best tool accuracy in sync (was previously left
                # stale, so the final report could understate it)
                best_val_accuracy_1 = val_accuracy_1
                correspond_train_acc_1 = train_accuracy_1
                correspond_train_acc_2 = train_accuracy_2
                best_model_wts = copy.deepcopy(model.state_dict())
            elif val_accuracy_1 == best_val_accuracy_1:
                if train_accuracy_2 > correspond_train_acc_2:
                    correspond_train_acc_2 = train_accuracy_2
                    correspond_train_acc_1 = train_accuracy_1
                    best_model_wts = copy.deepcopy(model.state_dict())
                elif train_accuracy_2 == correspond_train_acc_2:
                    # final tie-break on the corresponding train tool accuracy
                    # (was wrongly compared against best_val_accuracy_1)
                    if train_accuracy_1 > correspond_train_acc_1:
                        correspond_train_acc_1 = train_accuracy_1
                        best_model_wts = copy.deepcopy(model.state_dict())

        if val_accuracy_2 > 0.885:
            save_val_1 = int("{:4.0f}".format(val_accuracy_1 * 10000))
            save_val_2 = int("{:4.0f}".format(val_accuracy_2 * 10000))
            save_train_1 = int("{:4.0f}".format(train_accuracy_1 * 10000))
            save_train_2 = int("{:4.0f}".format(train_accuracy_2 * 10000))
            public_name = "cnn_lstm_p2t" \
                          + "_epoch_" + str(epochs) \
                          + "_length_" + str(sequence_length) \
                          + "_opt_" + str(optimizer_choice) \
                          + "_mulopt_" + str(multi_optim) \
                          + "_flip_" + str(use_flip) \
                          + "_crop_" + str(crop_type) \
                          + "_batch_" + str(train_batch_size) \
                          + "_train1_" + str(save_train_1) \
                          + "_train2_" + str(save_train_2) \
                          + "_val1_" + str(save_val_1) \
                          + "_val2_" + str(save_val_2)
            model_name = public_name + ".pth"
            torch.save(best_model_wts, model_name)

        record_np[epoch, 0] = train_accuracy_1
        record_np[epoch, 1] = train_accuracy_3
        record_np[epoch, 2] = train_accuracy_2
        record_np[epoch, 3] = train_average_loss_1
        record_np[epoch, 4] = train_average_loss_2
        record_np[epoch, 5] = train_average_loss_3

        record_np[epoch, 6] = val_accuracy_1
        record_np[epoch, 7] = val_accuracy_3
        # fixed: was written to column 7 again, clobbering val_accuracy_3
        # and leaving column 8 permanently zero
        record_np[epoch, 8] = val_accuracy_2
        record_np[epoch, 9] = val_average_loss_1
        record_np[epoch, 10] = val_average_loss_2
        record_np[epoch, 11] = val_average_loss_3

    print('best accuracy_1: {:.4f} cor train accu_1: {:.4f}'.format(
        best_val_accuracy_1, correspond_train_acc_1))
    print('best accuracy_2: {:.4f} cor train accu_2: {:.4f}'.format(
        best_val_accuracy_2, correspond_train_acc_2))

    record_name = public_name + ".npy"
    np.save(record_name, record_np)
示例#21
0

if __name__ == '__main__':
    # Build the face-recognition backbone chosen in the config.
    opt = Config()
    if opt.backbone == 'resnet18':
        model = resnet_face18(opt.use_se)
    elif opt.backbone == 'resnet34':
        model = resnet34()
    elif opt.backbone == 'resnet50':
        model = resnet50()
    # NOTE(review): if opt.backbone is none of the above, `model` is
    # undefined and DataParallel(model) below raises NameError — confirm
    # the config guarantees one of these values.
    metric_fc = Arcface()
    model = DataParallel(model)
    metric_fc = DataParallel(metric_fc)
    device = torch.device('cpu')  # checkpoints are loaded onto CPU

    # Load backbone weights, keeping only keys that exist in this model
    # (tolerates checkpoints saved from a slightly different architecture).
    model_dict = model.state_dict()
    pretrained_dict = torch.load("./resnet18_0.pth", map_location=device)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Same partial-load strategy for the ArcFace metric head.
    metric_fc_dict = metric_fc.state_dict()
    pretrained_dict1 = torch.load("./checkpoints/arcface_196.pth", map_location=device)
    pretrained_dict1 = {k: v for k, v in pretrained_dict1.items() if k in metric_fc_dict}
    metric_fc_dict.update(pretrained_dict1)
    metric_fc.load_state_dict(metric_fc_dict)

    # 1. create the TCP server socket
    tcp_sever_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # 2. bind to all local interfaces
    # NOTE(review): port 90 is privileged (<1024) and needs root; the
    # commented-out 7890 suggests the intended unprivileged port — confirm.
    tcp_sever_socket.bind(("", 90))  # 7890))
示例#22
0
class SMSG_GAN:
    """ Unconditional SMSG_GAN

        args:
            depth: depth of the GAN (will be used for each generator and discriminator)
            latent_size: latent size of the manifold used by the GAN
            device: device to run the GAN on (GPU / CPU)
    """

    def __init__(self, depth=7, latent_size=512, device=th.device("cpu")):
        """ constructor for the class """
        from torch.nn import DataParallel

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size).to(device)
        self.dis = Discriminator(depth, latent_size).to(device)

        if device == th.device("cuda"):  # apply the data parallel if device is GPU
            # wrap the instances already moved to the device; previously
            # brand-new (differently initialized, CPU-resident) networks
            # were constructed here, discarding the ones created above
            self.gen = DataParallel(self.gen)
            self.dis = DataParallel(self.dis)

        # state of the object
        self.latent_size = latent_size
        self.depth = depth
        self.device = device

        # by default the generator and discriminator are in eval mode
        self.gen.eval()
        self.dis.eval()

    def generate_samples(self, num_samples):
        """
        generate samples using this gan
        :param num_samples: number of samples to be generated
        :return: generated samples tensors: list[ Tensor(B x H x W x C)]
                 generated attn-map tensors: list[ Tensor(B x H x W x H x W)]
        """
        noise = th.randn(num_samples, self.latent_size).to(self.device)
        generated_images, attention_maps = self.gen(noise)

        # reshape the generated images (NCHW -> NHWC, rescale [-1,1] -> [0,1])
        generated_images = list(map(lambda x: (x.detach().permute(0, 2, 3, 1) / 2) + 0.5,
                                    generated_images))

        attention_maps = list(map(lambda x: x.detach(), attention_maps))

        return generated_images, attention_maps

    def optimize_discriminator(self, dis_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on discriminator using the batch of data
        :param dis_optim: discriminator optimizer
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
                           should contain a list of tensors at different scales
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples, detached so generator is not updated
        fake_samples, _ = self.gen(noise)
        fake_samples = list(map(lambda x: x.detach(), fake_samples))

        loss = loss_fn.dis_loss(real_batch, fake_samples)

        # optimize discriminator
        dis_optim.zero_grad()
        loss.backward()
        dis_optim.step()

        return loss.item()

    def optimize_generator(self, gen_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on generator using the batch of data
        :param gen_optim: generator optimizer
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
                           should contain a list of tensors at different scales
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples
        fake_samples, _ = self.gen(noise)

        loss = loss_fn.gen_loss(real_batch, fake_samples)

        # optimize generator (comment previously said "discriminator")
        gen_optim.zero_grad()
        loss.backward()
        gen_optim.step()

        return loss.item()

    @staticmethod
    def create_grid(samples, img_files):
        """
        utility function to create a grid of GAN samples
        :param samples: generated samples for storing list[Tensors]
        :param img_files: list of names of files to write
        :return: None (saves multiple files)
        """
        from torchvision.utils import save_image
        from numpy import sqrt

        samples = list(map(lambda x: th.clamp((x.detach() / 2) + 0.5, min=0, max=1),
                           samples))

        # save the images:
        for sample, img_file in zip(samples, img_files):
            save_image(sample, img_file, nrow=int(sqrt(sample.shape[0])))

    def train(self, data, gen_optim, dis_optim, loss_fn,
              start=1, num_epochs=12, feedback_factor=10, checkpoint_factor=1,
              data_percentage=100, num_samples=64,
              log_dir=None, sample_dir="./samples",
              save_dir="./models"):
        """
        method to train the SMSG-GAN network
        :param data: object of pytorch dataloader which provides iterator to the data
        :param gen_optim: optimizer for the generator parameters
        :param dis_optim: optimizer for discriminator parameters
        :param loss_fn: object of GANLoss (defines the loss function)
        :param start: starting epoch number
        :param num_epochs: ending epoch number
        :param feedback_factor: number of samples (logs) generated per epoch
        :param checkpoint_factor: model saved after these many epochs
        :param data_percentage: amount of data to be used for training
        :param num_samples: number of samples in the generated sample grid
                            (preferably a perfect square number)
        :param log_dir: path to the directory for saving the loss.log file
        :param sample_dir: path to the directory for saving the generated samples
        :param save_dir: path to the directory for saving trained models
        :return: None (saves model on disk)
        """

        from torch.nn.functional import avg_pool2d

        # turn the generator and discriminator into train mode
        self.gen.train()
        self.dis.train()

        assert isinstance(gen_optim, th.optim.Optimizer), \
            "gen_optim is not an Optimizer"
        assert isinstance(dis_optim, th.optim.Optimizer), \
            "dis_optim is not an Optimizer"

        print("Starting the training process ... ")

        # create fixed_input for debugging
        fixed_input = th.randn(num_samples, self.latent_size).to(self.device)

        # create a global time counter
        global_time = time.time()

        for epoch in range(start, num_epochs + 1):
            # record time at the start of epoch (renamed from `start`, which
            # shadowed the starting-epoch parameter)
            epoch_start = timeit.default_timer()

            print("\nEpoch: %d" % epoch)
            total_batches = len(data)

            limit = int((data_percentage / 100) * total_batches)
            # clamp to >= 1: int(limit / feedback_factor) could be 0 for
            # small datasets, crashing the `%` feedback check below
            feedback_interval = max(int(limit / feedback_factor), 1)

            for (i, batch) in enumerate(data, 1):

                # extract current batch of data for training
                images = batch.to(self.device)
                extracted_batch_size = images.shape[0]

                # create a list of downsampled images from the real images:
                images = [images] + [avg_pool2d(images, int(np.power(2, i)))
                                     for i in range(1, self.depth)]
                images = list(reversed(images))

                gan_input = th.randn(
                    extracted_batch_size, self.latent_size).to(self.device)

                # optimize the discriminator:
                dis_loss = self.optimize_discriminator(dis_optim, gan_input,
                                                       images, loss_fn)

                # optimize the generator:
                # resample from the latent noise
                gan_input = th.randn(
                    extracted_batch_size, self.latent_size).to(self.device)
                gen_loss = self.optimize_generator(gen_optim, gan_input,
                                                   images, loss_fn)

                # provide a loss feedback
                if i % feedback_interval == 0 or i == 1:
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print("\nElapsed [%s] batch: %d  d_loss: %f  g_loss: %f"
                          % (elapsed, i, dis_loss, gen_loss))
                    # unwrap DataParallel when present; on CPU the networks
                    # have no `.module` attribute and the old code crashed
                    print("Generator_gammas:",
                          getattr(self.gen, "module", self.gen).get_gammas())
                    print("Discriminator_gammas:",
                          getattr(self.dis, "module", self.dis).get_gammas())

                    # also write the losses to the log file:
                    if log_dir is not None:
                        log_file = os.path.join(log_dir, "loss.log")
                        os.makedirs(os.path.dirname(log_file), exist_ok=True)
                        with open(log_file, "a") as log:
                            log.write(str(dis_loss) + "\t" + str(gen_loss) + "\n")

                    # create a grid of samples and save it
                    reses = [str(int(np.power(2, dep))) + "_x_"
                             + str(int(np.power(2, dep)))
                             for dep in range(2, self.depth + 2)]
                    gen_img_files = [os.path.join(sample_dir, res, "gen_" +
                                                  str(epoch) + "_" +
                                                  str(i) + ".png")
                                     for res in reses]

                    # Make sure all the required directories exist
                    # otherwise make them
                    os.makedirs(sample_dir, exist_ok=True)
                    for gen_img_file in gen_img_files:
                        os.makedirs(os.path.dirname(gen_img_file), exist_ok=True)

                    self.create_grid(self.gen(fixed_input)[0], gen_img_files)

                if i > limit:
                    break

            # calculate the time required for the epoch
            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - epoch_start))

            if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == num_epochs:
                os.makedirs(save_dir, exist_ok=True)
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(epoch) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(epoch) + ".pth")

                th.save(self.gen.state_dict(), gen_save_file)
                th.save(self.dis.state_dict(), dis_save_file)

        print("Training completed ...")

        # return the generator and discriminator back to eval mode
        self.gen.eval()
        self.dis.eval()
示例#23
0
def train(args):
    """Train an EfficientNet-b0 anti-spoofing classifier.

    Sets up (multi-)GPU execution, a dated log/checkpoint directory, the
    FASD training dataloader, SGD with a MultiStep LR schedule, and runs the
    training loop with periodic visdom plotting, checkpointing, and an
    optional test hook.

    :param args: parsed CLI namespace; uses gpus, save_dir, total_epoch,
                 save_freq, test_freq, has_test.
    """
    # gpu init
    multi_gpus = False
    if len(args.gpus.split(',')) > 1:
        multi_gpus = True
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # log init: one sub-directory per calendar day
    save_dir = os.path.join(args.save_dir, datetime.now().date().strftime('%Y%m%d'))
    if not os.path.exists(save_dir):
        #raise NameError('model dir exists!')
        os.makedirs(save_dir)
    logging = init_log(save_dir)
    _print = logging.info
    # summary(net.to(config.device), (3,112,112))
    # define transform
    transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    net = EfficientNet.from_name('efficientnet-b0', num_classes=2)

    # training dataset (binary live/spoof labels)
    trainset = ANTI(train_root="/mnt/sda3/data/FASD", file_list = "train.txt", transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size = 2,
                                             shuffle=True, num_workers=8, drop_last=False)

    # define optimizers for different layer
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer_ft = optim.SGD([
        {'params': net.parameters(), 'weight_decay': 5e-4},
    ], lr=0.001, momentum=0.9, nesterov=True)

    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones= [6, 10, 30], gamma=0.1)
    if multi_gpus:
        net = DataParallel(net).to(device)
    else:
        net = net.to(device)

    total_iters = 1
    vis = Visualizer(env= "effiction")

    for epoch in range(1, args.total_epoch + 1):
        exp_lr_scheduler.step()
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()
        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            optimizer_ft.zero_grad()
            raw_logits = net(img)
            total_loss = criterion(raw_logits, label)
            total_loss.backward()
            optimizer_ft.step()
            # print train information
            if total_iters % 200 == 0:
                # current training accuracy; compare tensors directly —
                # np.array() on a CUDA tensor raised a TypeError here
                _, predict = torch.max(raw_logits.data, 1)
                total = label.size(0)
                correct = (predict == label).sum().item()
                # 200 iterations since the last report (was divided by 100)
                time_cur = (time.time() - since) / 200
                since = time.time()
                vis.plot_curves({'softmax loss': total_loss.item()}, iters=total_iters, title='train loss',
                                xlabel='iters', ylabel='train loss')
                vis.plot_curves({'train accuracy': correct / total}, iters=total_iters, title='train accuracy', xlabel='iters',
                                ylabel='train accuracy')

                print("Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters, epoch, total_loss.item(), correct/total, time_cur, exp_lr_scheduler.get_lr()[0]))

            # save model
            if total_iters % args.save_freq == 0:
                msg = 'Saving checkpoint: {}'.format(total_iters)
                _print(msg)
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                else:
                    net_state_dict = net.state_dict()

                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)

                torch.save({
                    'iters': total_iters,
                    'net_state_dict': net_state_dict},
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0 and args.has_test:
                # test model on lfw
                net.eval()
                # NOTE(review): lfw_accs is never defined in this scope, so
                # enabling has_test raises NameError — the LFW evaluation
                # that should populate it is missing; wire it up before use.
                _print('LFW Ave Accuracy: {:.4f}'.format(np.mean(lfw_accs) * 100))

                net.train()
            total_iters += 1
    print('finishing training')
示例#24
0
class Trainer(object):
    """Generic train/validate driver for pose-regression models.

    Wraps a model, optimizer and criteria; handles optional DataParallel
    wrapping, checkpoint save/restore, TensorBoard logging and the epoch
    loop (``run``).

    Fixes over the original Python-2 era code:
      * ``.cuda(async=True)`` -> ``.cuda(non_blocking=True)`` (``async`` is a
        reserved word since Python 3.7, so the old form is a SyntaxError)
      * ``xrange`` -> ``range``
      * ``print`` statement -> ``print()`` function
      * ``dict.has_key(k)`` -> ``k in dict``
      * ``loss.data[0]`` -> ``loss.item()`` (0-dim tensor indexing removed)
    """

    def __init__(self,
                 model,
                 optimizer,
                 configuration,
                 train_criterion,
                 train_dataloader,
                 val_dataloader,
                 val_criterion=None,
                 result_criterion=None,
                 **kwargs):
        """Set up model, criteria, logging, and optionally resume training.

        Parameters
        ----------
        model : torch.nn.Module
            Network to train; wrapped in DataParallel when >1 GPU is visible.
        optimizer : torch.optim.Optimizer
        configuration : object
            Config exposing attributes used throughout (``log_output_dir``,
            ``tf``, ``seed``, ``cuda``, ``n_epochs``, ...) and item access
            for ``'checkpoint'``.
        train_criterion : callable
            Loss used for training (and as fallback for the others).
        train_dataloader, val_dataloader : torch.utils.data.DataLoader
        val_criterion : callable, optional
            Validation loss; defaults to ``train_criterion``.
        result_criterion : callable, optional
            Metric producing (translation, rotation) errors; defaults to
            ``val_criterion``.
        """
        self.config = configuration

        # Wrap with DataParallel only when several GPUs are visible.
        if torch.cuda.device_count() == 1:
            self.model = model
        else:
            print("Parallel data processing...")
            self.model = DataParallel(model)
        self.train_criterion = train_criterion

        self.best_model = None
        self.best_model_filename = osp.join(self.config.log_output_dir,
                                            self.config.best_model_name)

        # Fall back to the training criterion when no dedicated criterion
        # is supplied.
        if val_criterion is None:
            self.val_criterion = train_criterion
        else:
            self.val_criterion = val_criterion
        if result_criterion is None:
            print("result_criterion is None")
            self.result_criterion = self.val_criterion
        else:
            self.result_criterion = result_criterion

        self.optimizer = optimizer

        if self.config.tf:
            self.writer = SummaryWriter(log_dir=self.config.tf_dir)
            self.loss_win = 'loss_win'
            self.result_win = 'result_win'
            self.criterion_params_win = 'cparam_win'
            # NOTE(review): assumes every learnable criterion parameter is a
            # 1-element tensor -- confirm against the criterion implementation.
            criterion_params = {
                k: v.data.cpu().numpy()[0]
                for k, v in self.train_criterion.named_parameters()
            }
            self.n_criterion_params = len(criterion_params)

        # Seed CPU and GPU RNGs for reproducibility.
        torch.manual_seed(self.config.seed)
        if self.config.cuda:
            torch.cuda.manual_seed(self.config.seed)

        # Optionally restore model/optimizer/criterion state from disk.
        self.start_epoch = 1
        if self.config["checkpoint"]:
            self.load_checkpoint()
        else:
            print("No checkpoint file")
        print('start_epoch = {}'.format(self.start_epoch))

        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        # Pose normalisation statistics (mean, stddev) kept on the GPU for
        # de-normalising predictions inside result_criterion.
        self.pose_m, self.pose_s = np.loadtxt(self.config.pose_stats_file)
        self.pose_m = Variable(torch.from_numpy(self.pose_m).float(),
                               requires_grad=False).cuda(non_blocking=True)
        self.pose_s = Variable(torch.from_numpy(self.pose_s).float(),
                               requires_grad=False).cuda(non_blocking=True)

        if self.config.cuda:
            self.model.cuda()
            self.train_criterion.cuda()
            self.val_criterion.cuda()

    def run(self):
        """Main loop: validate (tracking the best model), snapshot, train."""
        n_epochs = self.config.n_epochs
        for epoch in range(self.start_epoch, n_epochs + 1):
            # Validate first so a freshly resumed model is evaluated before
            # any further training mutates it.
            val_loss = None
            if self.config.do_val and ((epoch % self.config.val_freq == 0) or
                                       (epoch == n_epochs - 1)):
                val_loss = self.validate(epoch)
                # A None best loss (e.g. checkpoint saved without a loss)
                # is always replaced.
                if self.best_model is None or self.best_model[
                        'loss'] is None or val_loss < self.best_model['loss']:
                    self.best_model = self.pack_checkpoint(epoch, val_loss)
                    torch.save(
                        self.best_model,
                        osp.join(self.config.checkpoint_dir,
                                 self.config.best_model_name))
                    print("Best model saved at epoch {:d} in name of {:s}".
                          format(epoch, self.config.best_model_name))

            # Periodic snapshot regardless of validation outcome.
            if epoch % self.config.snapshot == 0:
                checkpoint = self.pack_checkpoint(epoch=epoch, loss=val_loss)
                fn = osp.join(self.config.checkpoint_dir,
                              "epoch_{:04d}.pth.tar".format(epoch))
                torch.save(checkpoint, fn)
                print('Epoch {:d} checkpoint saved: {:s}'.format(epoch, fn))

            self.train(epoch)

    def get_result_loss(self, output, target):
        """Return (translation_error, rotation_error) for ``output``."""
        target_var = Variable(target,
                              requires_grad=False).cuda(non_blocking=True)
        t_loss, q_loss = self.result_criterion(output, target_var, self.pose_m,
                                               self.pose_s)

        return t_loss, q_loss

    def train(self, epoch):
        """Run one training epoch, logging timing/loss every print_freq."""
        self.model.train()
        train_data_time = Timer()   # time spent fetching data
        train_batch_time = Timer()  # time spent in forward/backward
        train_data_time.tic()
        for batch_idx, (data, target) in enumerate(self.train_dataloader):
            train_data_time.toc()

            train_batch_time.tic()
            loss, output = self.step_feedfwd(data,
                                             self.model,
                                             target=target,
                                             criterion=self.train_criterion,
                                             optim=self.optimizer,
                                             train=True)

            t_loss, q_loss = self.get_result_loss(output, target)
            train_batch_time.toc()

            if batch_idx % self.config.print_freq == 0:
                # Global iteration index across epochs, for TensorBoard x-axis.
                n_itr = (epoch - 1) * len(self.train_dataloader) + batch_idx
                epoch_count = float(n_itr) / len(self.train_dataloader)
                print(
                    'Train {:s}: Epoch {:d}\t'
                    'Batch {:d}/{:d}\t'
                    'Data time {:.4f} ({:.4f})\t'
                    'Batch time {:.4f} ({:.4f})\t'
                    'Loss {:f}'.format(self.config.experiment, epoch,
                                       batch_idx,
                                       len(self.train_dataloader) - 1,
                                       train_data_time.last_time(),
                                       train_data_time.avg_time(),
                                       train_batch_time.last_time(),
                                       train_batch_time.avg_time(), loss))
                if self.config.tf:
                    self.writer.add_scalars(self.loss_win,
                                            {"training_loss": loss}, n_itr)
                    self.writer.add_scalars(
                        self.result_win, {
                            "training_t_loss": t_loss.item(),
                            "training_q_loss": q_loss.item()
                        }, n_itr)
                    # Also track any learnable criterion parameters.
                    if self.n_criterion_params:
                        for name, v in self.train_criterion.named_parameters():
                            v = v.data.cpu().numpy()[0]
                            self.writer.add_scalars(self.criterion_params_win,
                                                    {name: v}, n_itr)

            train_data_time.tic()

    def validate(self, epoch):
        """Evaluate on the validation set; return the mean translation error."""
        val_batch_time = Timer()  # time for step in each batch
        val_loss = AverageMeter()
        t_loss = AverageMeter()
        q_loss = AverageMeter()
        self.model.eval()
        val_data_time = Timer()  # time for data retrieving
        val_data_time.tic()
        for batch_idx, (data, target) in enumerate(self.val_dataloader):
            val_data_time.toc()

            val_batch_time.tic()
            # train=False: the optimizer is passed through but never stepped.
            loss, output = self.step_feedfwd(
                data,
                self.model,
                target=target,
                criterion=self.val_criterion,
                optim=self.optimizer,
                train=False)
            val_batch_time.toc()
            val_loss.update(loss)

            t_loss_batch, q_loss_batch = self.get_result_loss(output, target)
            t_loss.update(t_loss_batch.item())
            q_loss.update(q_loss_batch.item())

            if batch_idx % self.config.print_freq == 0:
                print(
                    'Val {:s}: Epoch {:d}\t'
                    'Batch {:d}/{:d}\t'
                    'Data time {:.4f} ({:.4f})\t'
                    'Batch time {:.4f} ({:.4f})\t'
                    'Loss {:f}'.format(self.config.experiment, epoch,
                                       batch_idx,
                                       len(self.val_dataloader) - 1,
                                       val_data_time.last_time(),
                                       val_data_time.avg_time(),
                                       val_batch_time.last_time(),
                                       val_batch_time.avg_time(), loss))

            val_data_time.tic()

        print('Val {:s}: Epoch {:d}, val_loss {:f}'.format(
            self.config.experiment, epoch, val_loss.average()))
        print('Mean error in translation: {:3.2f} m\n'
              'Mean error in rotation: {:3.2f} degree'.format(
                  t_loss.average(), q_loss.average()))

        if self.config.tf:
            n_itr = (epoch - 1) * len(self.train_dataloader)
            self.writer.add_scalars(self.loss_win,
                                    {"val_loss": val_loss.average()}, n_itr)
            self.writer.add_scalars(self.result_win, {
                "val_t_loss": t_loss.average(),
                "val_q_loss": q_loss.average()
            }, n_itr)

        return t_loss.average()

    def step_feedfwd(self,
                     data,
                     model,
                     target=None,
                     criterion=None,
                     train=True,
                     **kwargs):
        """One forward (and, when ``train``, backward+update) pass.

        Returns ``(loss_value, output)``; loss_value is 0 when no criterion
        is given.  The optimizer must be supplied via ``optim=`` kwarg.
        """
        optim = kwargs["optim"]
        if train:
            assert criterion is not None
            data_var = Variable(data,
                                requires_grad=True).cuda(non_blocking=True)
            target_var = Variable(target,
                                  requires_grad=False).cuda(non_blocking=True)
        else:
            data_var = Variable(data,
                                requires_grad=False).cuda(non_blocking=True)
            target_var = Variable(target,
                                  requires_grad=False).cuda(non_blocking=True)

        output = model(data_var)

        if criterion is not None:
            loss = criterion(output, target_var)

            if train:
                optim.zero_grad()
                loss.backward()
                # Optional gradient clipping for training stability.
                if self.config.max_grad_norm > 0.0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.config.max_grad_norm)
                optim.step()
            return loss.item(), output
        else:
            return 0, output

    def load_checkpoint(self):
        """Restore model (and optionally criterion/optimizer) state.

        Exits the process when the configured checkpoint file is missing.
        """
        checkpoint_file = self.config.checkpoint
        resume_optim = self.config.resume_optim
        if osp.isfile(checkpoint_file):
            # When running on CPU, remap CUDA storages to host memory.
            loc_func = None if self.config.cuda else lambda storage, loc: storage
            checkpoint = torch.load(checkpoint_file, map_location=loc_func)
            self.best_model = checkpoint
            load_state_dict(self.model, checkpoint["model_state_dict"])

            self.start_epoch = checkpoint['epoch']

            if 'criterion_state_dict' in checkpoint:
                c_state = checkpoint['criterion_state_dict']
                # Pad any criterion parameters absent from the checkpoint
                # with zeros so load_state_dict does not fail.
                append_dict = {
                    k: torch.Tensor([0, 0])
                    for k, _ in self.train_criterion.named_parameters()
                    if not k in c_state
                }
                c_state.update(append_dict)
                self.train_criterion.load_state_dict(c_state)

            print("Loaded checkpoint {:s} epoch {:d}".format(
                checkpoint_file, checkpoint['epoch']))
            print("Loss of loaded model = {}".format(checkpoint['loss']))

            if resume_optim:
                print("Load parameters in optimizer")
                self.optimizer.load_state_dict(checkpoint["optim_state_dict"])
                # Optimizer state tensors are loaded on CPU; move them back
                # to the GPU so optimizer.step() does not mix devices.
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()

            else:
                print("Notice: load checkpoint but didn't load optimizer.")
        else:
            print("Can't find specified checkpoint.!")
            exit(-1)

    def pack_checkpoint(self, epoch, loss=None):
        """Bundle all trainable state into a serialisable dict."""
        checkpoint_dict = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optim_state_dict': self.optimizer.state_dict(),
            'criterion_state_dict': self.train_criterion.state_dict(),
            'loss': loss
        }
        return checkpoint_dict
示例#25
0
                # errorg_meter.add(error_g.data[0])

            proBar.show(error_d.data[0], error_g.data[0])
            fix_fake_imgs = netG(fix_noises)
            # if opt.vis and ii%opt.plot_every == opt.plot_every-1:
            #     ## 可视化
            #     if os.path.exists(opt.debug_file):
            #         ipdb.set_trace()
            #     fix_fake_imgs = netg(fix_noises)
            #     vis.images(fix_fake_imgs.data.cpu().numpy()[:64]*0.5+0.5,win='fixfake')
            #     vis.images(real_img.data.cpu().numpy()[:64]*0.5+0.5,win='real')
            #     vis.plot('errord',errord_meter.value()[0])
            #     vis.plot('errorg',errorg_meter.value()[0])

        if epoch % config.DECAY_EVERY == 0:
            # 保存模型、图片
            tv.utils.save_image(fix_fake_imgs.data[:8],
                                '%s/%s.png' % (config.SAVE_PATH, epoch),
                                normalize=True,
                                range=(-1, 1))
            t.save(netD.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netG.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            # errord_meter.reset()
            # errorg_meter.reset()
            optimizer_g = t.optim.Adam(netG.parameters(),
                                       config.LR_GENERATOR,
                                       betas=(config.BETA1, 0.999))
            optimizer_d = t.optim.Adam(netD.parameters(),
                                       config.LR_DISCRIMINATOR,
                                       betas=(config.BETA1, 0.999))
示例#26
0
def train(args):
    """Train an image-classification model with periodic in-epoch validation.

    Builds the model (DataParallel when several GPUs are available), the
    optimizer (Adam or SGD) and LR scheduler (ReduceLROnPlateau or
    CosineAnnealingLR), then loops over epochs, validating every
    ``args.iter_val`` iterations and checkpointing on top-1 improvement.
    """
    print('start training...')
    model, model_file = create_model(args)
    #model = model.cuda()
    # Preserve the custom `.name` attribute, which DataParallel would hide.
    if torch.cuda.device_count() > 1:
        model_name = model.name
        model = DataParallel(model)
        model.name = model_name
    model = model.cuda()

    if args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=0.0001)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=0.0001)

    # mode='max': plateau scheduler monitors top-1 accuracy (higher is better).
    if args.lrs == 'plateau':
        lr_scheduler = ReduceLROnPlateau(optimizer,
                                         mode='max',
                                         factor=args.factor,
                                         patience=args.patience,
                                         min_lr=args.min_lr)
    else:
        lr_scheduler = CosineAnnealingLR(optimizer,
                                         args.t_max,
                                         eta_min=args.min_lr)
    #ExponentialLR(optimizer, 0.9, last_epoch=-1) #CosineAnnealingLR(optimizer, 15, 1e-7)

    _, val_loader = get_train_val_loaders(batch_size=args.batch_size,
                                          val_num=args.val_num)

    best_top1_acc = 0.

    # Column header for the progress table printed below.
    print(
        'epoch |    lr    |      %        |  loss  |  avg   |  loss  |  top1  | top10  |  best  | time |  save |'
    )

    # Optional baseline validation before any training (skipped by flag).
    if not args.no_first_val:
        top10_acc, best_top1_acc, total_loss = validate(
            args, model, val_loader)
        print(
            'val   |          |               |        |        | {:.4f} | {:.4f} | {:.4f} | {:.4f} |      |       |'
            .format(total_loss, best_top1_acc, top10_acc, best_top1_acc))

    # Validation-only mode: stop after the baseline evaluation above.
    if args.val:
        return

    model.train()

    # Advance the scheduler once before training (keeps original behavior).
    if args.lrs == 'plateau':
        lr_scheduler.step(best_top1_acc)
    else:
        lr_scheduler.step()
    train_iter = 0

    for epoch in range(args.start_epoch, args.epochs):
        # Fresh loaders each epoch (loaders may reshuffle/resample).
        train_loader, val_loader = get_train_val_loaders(
            batch_size=args.batch_size,
            dev_mode=args.dev_mode,
            val_num=args.val_num)

        train_loss = 0

        current_lr = get_lrs(
            optimizer)  #optimizer.state_dict()['param_groups'][2]['lr']
        bg = time.time()
        for batch_idx, data in enumerate(train_loader):
            train_iter += 1
            img, target = data
            img, target = img.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(img)

            loss = criterion(args, output, target)
            loss.backward()
            optimizer.step()

            # '\r' + end='' keeps the progress line updating in place.
            train_loss += loss.item()
            print('\r {:4d} | {:.6f} | {:06d}/{} | {:.4f} | {:.4f} |'.format(
                epoch, float(current_lr[0]), args.batch_size * (batch_idx + 1),
                train_loader.num, loss.item(), train_loss / (batch_idx + 1)),
                  end='')

            # In-epoch validation / checkpointing every args.iter_val steps.
            if train_iter > 0 and train_iter % args.iter_val == 0:
                top10_acc, top1_acc, total_loss = validate(
                    args, model, val_loader)

                # Save when improved (or always, per flag); unwrap
                # DataParallel so the checkpoint loads on a single GPU.
                _save_ckp = ''
                if args.always_save or top1_acc > best_top1_acc:
                    best_top1_acc = top1_acc
                    if isinstance(model, DataParallel):
                        torch.save(model.module.state_dict(), model_file)
                    else:
                        torch.save(model.state_dict(), model_file)
                    _save_ckp = '*'
                print(' {:.4f} | {:.4f} | {:.4f} | {:.4f} | {:.2f} |  {:4s} |'.
                      format(total_loss, top1_acc, top10_acc, best_top1_acc,
                             (time.time() - bg) / 60, _save_ckp))

                # validate() switches to eval(); restore training mode.
                model.train()

                if args.lrs == 'plateau':
                    lr_scheduler.step(top1_acc)
                else:
                    lr_scheduler.step()
                current_lr = get_lrs(optimizer)
示例#27
0
def main(args):
    """Train/ Cross validate for data source = YogiDB.

    Builds train/val dataloaders from YogiDB annotations, constructs an
    hourglass model (hg1/hg2/hg8), optionally resumes from a checkpoint,
    then trains for ``args.epochs`` epochs, logging and checkpointing.

    Fixes over the original:
      * ``torch.optim.Adam`` has no ``momentum`` argument -- passing it
        raised ``TypeError``; removed (Adam uses ``betas`` internally).
      * Unknown-architecture error referenced ``args.model`` (which does not
        exist) instead of ``args.arch``, masking the real error.
    """
    # Create data loader
    """Generic(data.Dataset)(image_set, annotations,
                     is_train=True, inp_res=256, out_res=64, sigma=1,
                     scale_factor=0, rot_factor=0, label_type='Gaussian',
                     rgb_mean=RGB_MEAN, rgb_stddev=RGB_STDDEV)."""
    annotations_source = 'basic-thresholder'

    # Get the data from yogi
    db_obj = YogiDB(config.db_url)
    imageset = db_obj.get_filtered(ImageSet,
                                   name=args.image_set_name)
    annotations = db_obj.get_annotations(image_set_name=args.image_set_name,
                                         annotation_source=annotations_source)
    # Number of keypoints defines the number of output classes/heatmaps.
    pts = torch.Tensor(annotations[0]['joint_self'])
    num_classes = pts.size(0)
    crop_size = 512
    if args.crop:
        crop_size = args.crop
        crop = True
    else:
        crop = False

    # Using the default RGB mean and std dev as 0
    RGB_MEAN = torch.as_tensor([0.0, 0.0, 0.0])
    RGB_STDDEV = torch.as_tensor([0.0, 0.0, 0.0])

    dataset = Generic(image_set=imageset,
                      inp_res=args.inp_res,
                      out_res=args.out_res,
                      annotations=annotations,
                      mode=args.mode,
                      crop=crop, crop_size=crop_size,
                      rgb_mean=RGB_MEAN, rgb_stddev=RGB_STDDEV)

    # NOTE(review): train/val share the SAME dataset object, so the second
    # `is_train = False` assignment also applies to train_loader's dataset
    # -- confirm this is intended before relying on train augmentation.
    train_dataset = dataset
    train_dataset.is_train = True
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch, shuffle=True,
                              num_workers=args.workers, pin_memory=True)

    val_dataset = dataset
    val_dataset.is_train = False
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch, shuffle=False,
                            num_workers=args.workers, pin_memory=True)

    # Select the hardware device to use for inference.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default; the training helpers are
    # expected to re-enable them where needed.
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    if args.arch == 'hg1':
        model = hg1(pretrained=False, num_classes=num_classes)
    elif args.arch == 'hg2':
        model = hg2(pretrained=False, num_classes=num_classes)
    elif args.arch == 'hg8':
        model = hg8(pretrained=False, num_classes=num_classes)
    else:
        raise Exception('unrecognised model architecture: ' + args.arch)

    model = DataParallel(model).to(device)

    if args.optimizer == "Adam":
        # Adam does not accept `momentum` (it uses adaptive moments/betas).
        optimizer = Adam(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    else:
        optimizer = RMSprop(model.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    best_acc = 0

    # optionally resume from a checkpoint
    title = args.data_identifier + ' ' + args.arch
    if args.resume:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss, train_acc = do_training_epoch(train_loader, model, device, optimizer)

        # evaluate on validation set (debug mode also dumps a CSV log)
        if args.debug == 1:
            valid_loss, valid_acc, predictions, validation_log = do_validation_epoch(val_loader, model, device, False, True, os.path.join(args.checkpoint, 'debug.csv'), epoch + 1)
        else:
            valid_loss, valid_acc, predictions, _ = do_validation_epoch(val_loader, model, device, False)

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, predictions, is_best, checkpoint=args.checkpoint, snapshot=args.snapshot)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
示例#28
0
def Train(args, FeatExtor, DepthEstor, FeatEmbder, data_loader1_real,
          data_loader1_fake, data_loader2_real, data_loader2_fake,
          data_loader3_real, data_loader3_fake, data_loader_target,
          summary_writer, Saver, savefilename):
    """Meta-learning training loop over three real/fake source domains.

    Each step samples a batch from every domain, splits the domains into
    meta-train / meta-test, adapts the embedder on the meta-train domains
    via first-order gradient steps, evaluates the adapted weights on the
    meta-test domain, and optimizes the combined loss.

    Fixes over the original:
      * ``args.optimizer_meta is 'adam'`` compared identity with a string
        literal (implementation-dependent; SyntaxWarning) -- now ``==``.
      * ``model_save_path`` is computed up front so the final save cannot
        hit a NameError when the save branches never executed.
      * Removed the no-op ``Loss_dep_test = Loss_dep_test`` reassignment.
    """

    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    DepthEstor.train()
    FeatEmbder.train()

    FeatExtor = DataParallel(FeatExtor)
    DepthEstor = DataParallel(DepthEstor)

    # setup criterion and optimizer
    criterionCls = nn.BCEWithLogitsLoss()
    criterionDepth = torch.nn.MSELoss()

    if args.optimizer_meta == 'adam':

        # Single optimizer over all three sub-networks.
        optimizer_all = optim.Adam(itertools.chain(FeatExtor.parameters(),
                                                   DepthEstor.parameters(),
                                                   FeatEmbder.parameters()),
                                   lr=args.lr_meta,
                                   betas=(args.beta1, args.beta2))

    else:
        raise NotImplementedError('Not a suitable optimizer')

    # One "epoch" is as many steps as the longest source loader.
    iternum = max(len(data_loader1_real), len(data_loader1_fake),
                  len(data_loader2_real), len(data_loader2_fake),
                  len(data_loader3_real), len(data_loader3_fake))

    print('iternum={}'.format(iternum))

    # Snapshot directory, fixed for the whole run.
    model_save_path = os.path.join(args.results_path, 'snapshots',
                                   savefilename)

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.epochs):

        # Infinite iterators so shorter loaders restart within the epoch.
        data1_real = get_inf_iterator(data_loader1_real)
        data1_fake = get_inf_iterator(data_loader1_fake)

        data2_real = get_inf_iterator(data_loader2_real)
        data2_fake = get_inf_iterator(data_loader2_fake)

        data3_real = get_inf_iterator(data_loader3_real)
        data3_fake = get_inf_iterator(data_loader3_fake)

        for step in range(iternum):

            #============ one batch extraction ============#

            cat_img1_real, depth_img1_real, lab1_real = next(data1_real)
            cat_img1_fake, depth_img1_fake, lab1_fake = next(data1_fake)

            cat_img2_real, depth_img2_real, lab2_real = next(data2_real)
            cat_img2_fake, depth_img2_fake, lab2_fake = next(data2_fake)

            cat_img3_real, depth_img3_real, lab3_real = next(data3_real)
            cat_img3_fake, depth_img3_fake, lab3_fake = next(data3_fake)

            #============ one batch collection ============#
            # Concatenate real+fake per domain, move to GPU.

            catimg1 = torch.cat([cat_img1_real, cat_img1_fake], 0).cuda()
            depth_img1 = torch.cat([depth_img1_real, depth_img1_fake],
                                   0).cuda()
            lab1 = torch.cat([lab1_real, lab1_fake], 0).float().cuda()

            catimg2 = torch.cat([cat_img2_real, cat_img2_fake], 0).cuda()
            depth_img2 = torch.cat([depth_img2_real, depth_img2_fake],
                                   0).cuda()
            lab2 = torch.cat([lab2_real, lab2_fake], 0).float().cuda()

            catimg3 = torch.cat([cat_img3_real, cat_img3_fake], 0).cuda()
            depth_img3 = torch.cat([depth_img3_real, depth_img3_fake],
                                   0).cuda()
            lab3 = torch.cat([lab3_real, lab3_fake], 0).float().cuda()

            catimg = torch.cat([catimg1, catimg2, catimg3], 0)
            depth_GT = torch.cat([depth_img1, depth_img2, depth_img3], 0)
            label = torch.cat([lab1, lab2, lab3], 0)

            #============ domain list augmentation ============#
            # Randomly split domains into meta-train / meta-test per step.
            catimglist = [catimg1, catimg2, catimg3]
            lablist = [lab1, lab2, lab3]
            deplist = [depth_img1, depth_img2, depth_img3]

            domain_list = list(range(len(catimglist)))
            random.shuffle(domain_list)

            meta_train_list = domain_list[:args.metatrainsize]
            meta_test_list = domain_list[args.metatrainsize:]
            print('metatrn={}, metatst={}'.format(meta_train_list,
                                                  meta_test_list[0]))

            #============ meta training ============#

            Loss_dep_train = 0.0
            Loss_cls_train = 0.0

            adapted_state_dicts = []

            for index in meta_train_list:

                catimg_meta = catimglist[index]
                lab_meta = lablist[index]
                depGT_meta = deplist[index]

                # Shuffle within the batch so real/fake are interleaved.
                batchidx = list(range(len(catimg_meta)))
                random.shuffle(batchidx)

                img_rand = catimg_meta[batchidx, :]
                lab_rand = lab_meta[batchidx]
                depGT_rand = depGT_meta[batchidx, :]

                feat_ext_all, feat = FeatExtor(img_rand)
                pred = FeatEmbder(feat)
                depth_Pre = DepthEstor(feat_ext_all)

                Loss_cls = criterionCls(pred.squeeze(), lab_rand)
                Loss_dep = criterionDepth(depth_Pre, depGT_rand)

                Loss_dep_train += Loss_dep
                Loss_cls_train += Loss_cls

                # Inner-loop adaptation: one gradient step on the embedder
                # (create_graph=True keeps it differentiable for the outer
                # update).
                zero_param_grad(FeatEmbder.parameters())
                grads_FeatEmbder = torch.autograd.grad(Loss_cls,
                                                       FeatEmbder.parameters(),
                                                       create_graph=True)
                fast_weights_FeatEmbder = FeatEmbder.cloned_state_dict()

                adapted_params = OrderedDict()
                for (key, val), grad in zip(FeatEmbder.named_parameters(),
                                            grads_FeatEmbder):
                    adapted_params[key] = val - args.meta_step_size * grad
                    fast_weights_FeatEmbder[key] = adapted_params[key]

                adapted_state_dicts.append(fast_weights_FeatEmbder)

            #============ meta testing ============#
            Loss_dep_test = 0.0
            Loss_cls_test = 0.0

            index = meta_test_list[0]

            catimg_meta = catimglist[index]
            lab_meta = lablist[index]
            depGT_meta = deplist[index]

            batchidx = list(range(len(catimg_meta)))
            random.shuffle(batchidx)

            img_rand = catimg_meta[batchidx, :]
            lab_rand = lab_meta[batchidx]
            depGT_rand = depGT_meta[batchidx, :]

            feat_ext_all, feat = FeatExtor(img_rand)
            depth_Pre = DepthEstor(feat_ext_all)
            Loss_dep = criterionDepth(depth_Pre, depGT_rand)

            # Evaluate every adapted embedder on the held-out domain.
            for n_scr in range(len(meta_train_list)):
                a_dict = adapted_state_dicts[n_scr]

                pred = FeatEmbder(feat, a_dict)
                Loss_cls = criterionCls(pred.squeeze(), lab_rand)

                Loss_cls_test += Loss_cls

            Loss_dep_test = Loss_dep

            Loss_dep_train_ave = Loss_dep_train / len(meta_train_list)

            Loss_meta_train = Loss_cls_train + args.W_depth * Loss_dep_train
            Loss_meta_test = Loss_cls_test + args.W_depth * Loss_dep_test

            Loss_all = Loss_meta_train + args.W_metatest * Loss_meta_test

            optimizer_all.zero_grad()
            Loss_all.backward()
            optimizer_all.step()

            if (step + 1) % args.log_step == 0:
                errors = OrderedDict([
                    ('Loss_meta_train', Loss_meta_train.item()),
                    ('Loss_meta_test', Loss_meta_test.item()),
                    ('Loss_cls_train', Loss_cls_train.item()),
                    ('Loss_cls_test', Loss_cls_test.item()),
                    ('Loss_dep_train_ave', Loss_dep_train_ave.item()),
                    ('Loss_dep_test', Loss_dep_test.item()),
                ])
                Saver.print_current_errors((epoch + 1), (step + 1), errors)

            #============ tensorboard the log info ============#
            info = {
                'Loss_meta_train': Loss_meta_train.item(),
                'Loss_meta_test': Loss_meta_test.item(),
                'Loss_cls_train': Loss_cls_train.item(),
                'Loss_cls_test': Loss_cls_test.item(),
                'Loss_dep_train_ave': Loss_dep_train_ave.item(),
                'Loss_dep_test': Loss_dep_test.item(),
            }
            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            global_step += 1

            #############################
            # 2.4 save model parameters #
            #############################
            if ((step + 1) % args.model_save_step == 0):
                mkdir(model_save_path)

                torch.save(
                    FeatExtor.state_dict(),
                    os.path.join(
                        model_save_path,
                        "FeatExtor-{}-{}.pt".format(epoch + 1, step + 1)))
                torch.save(
                    FeatEmbder.state_dict(),
                    os.path.join(
                        model_save_path,
                        "FeatEmbder-{}-{}.pt".format(epoch + 1, step + 1)))
                torch.save(
                    DepthEstor.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DepthEstor-{}-{}.pt".format(epoch + 1, step + 1)))

        if ((epoch + 1) % args.model_save_epoch == 0):
            mkdir(model_save_path)

            torch.save(
                FeatExtor.state_dict(),
                os.path.join(model_save_path,
                             "FeatExtor-{}.pt".format(epoch + 1)))
            torch.save(
                FeatEmbder.state_dict(),
                os.path.join(model_save_path,
                             "FeatEmbder-{}.pt".format(epoch + 1)))
            torch.save(
                DepthEstor.state_dict(),
                os.path.join(model_save_path,
                             "DepthEstor-{}.pt".format(epoch + 1)))

    # Final snapshot after all epochs complete.
    mkdir(model_save_path)
    torch.save(FeatExtor.state_dict(),
               os.path.join(model_save_path, "FeatExtor-final.pt"))
    torch.save(FeatEmbder.state_dict(),
               os.path.join(model_save_path, "FeatEmbder-final.pt"))
    torch.save(DepthEstor.state_dict(),
               os.path.join(model_save_path, "DepthEstor-final.pt"))
示例#29
0
class BaseTrainer:
    """
    Base class for all trainers.

    Attributes
    ----------
    config : parse_config.ConfigParser
        The config parsing object.

    device : torch.device
        The device the model will be trained on.

    epochs : int
        Number of epochs to train over.

    logger : logging.Logger
        Logging object.

    loss_fn : callable
        Loss function.

    loss_args : dict of {str, Any}
        Keyword arguments of the loss function.

    metric_fns : list of callable
        List of metric functions.

    metric_args : list of dict of {str, Any}
        List of keyword arguments of the metric functions, matched by index.

    mnt_best : float
        Current best recorded value of the monitored metric.

    mnt_metric : str
        Name of the monitored metric (set only when monitoring is enabled).

    mnt_mode : str
        What to monitor ("off", "max", or "min")

    model : torch.nn.Module
        The model.

    monitor : str
        Monitoring specification from the config, e.g. "min val_loss" or
        "off" (used for early stopping and best-checkpoint selection).

    optimizer : torch.optim.Optimizer
        The optimizer.

    save_period : int
        How often (in epochs) to save a checkpoint.

    writer : TensorboardWriter
        Writer object for TensorBoard logging.

    Methods
    -------
    train()
        Full training logic.
    """
    def __init__(
        self,
        model: Module,
        loss_fn: Callable,
        loss_args: Dict[str, Any],
        metric_fns: List[Callable],
        metric_args: List[Dict[str, Any]],
        optimizer: Optimizer,
        config: ConfigParser,
    ):

        self.config: ConfigParser = config
        self.logger: Logger = config.get_logger("trainer",
                                                config["trainer"]["verbosity"])

        # Setup GPU device if available.
        self.device: torch.device
        device_ids: List[int]
        self.device, device_ids = self._prepare_device(config["n_gpu"])

        # Move model into configured device(s); wrap for multi-GPU training
        # only when more than one device is configured.
        self.model: Module = model.to(self.device)
        if len(device_ids) > 1:
            self.model = DataParallel(model, device_ids=device_ids)

        # Set loss function and arguments.
        self.loss_fn: Callable = loss_fn
        self.loss_args: Dict[str, Any] = loss_args

        # Set all metric functions and associated arguments.
        self.metric_fns: List[Callable] = metric_fns
        self.metric_args: List[Dict[str, Any]] = metric_args

        # Set optimizer.
        self.optimizer: Optimizer = optimizer

        # Set training configuration.
        cfg_trainer: Dict[str, Any] = config["trainer"]
        self.epochs: int = cfg_trainer["epochs"]
        self.save_period: int = cfg_trainer["save_period"]
        self.monitor: str = cfg_trainer.get("monitor", "off")

        # Configuration to monitor model performance and save best.
        if self.monitor == "off":
            self.mnt_mode: str = "off"
            self.mnt_best: float = 0
        else:
            self.mnt_metric: str
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            # Start from the worst possible value so the first valid
            # comparison always counts as an improvement.
            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop: float = cfg_trainer.get("early_stop", inf)

        self.start_epoch: int = 1
        self.checkpoint_dir: Path = config.save_dir

        # Setup visualization writer instance.
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer["tensorboard"])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch: int) -> Union[Dict[str, Any], NoReturn]:
        """
        Training logic for an epoch. If not implemented in child class, raise `NotImplementedError`.

        Parameters
        ----------
        epoch : int
            The current epoch.

        Returns
        -------
        dict
            A dictionary containing the logged information.

        Raises
        ------
        NotImplementedError
            If not implemented in child class.
        """
        raise NotImplementedError

    def train(self) -> None:
        """Full training logic."""
        # Bug fix: the early-stopping counter must exist before the first
        # comparison. Previously it was created only inside the improved /
        # KeyError branches, so a first epoch that failed to improve (e.g.
        # a NaN monitored metric) raised NameError on the increment below.
        not_improved_count: int = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result: dict = self._train_epoch(epoch)

            # Save logged information in log dict.
            log: Dict[str, float] = {"epoch": epoch}
            key: str
            value: Union[float, List[float]]
            for key, value in result.items():
                if key == "metrics":
                    # Metric values are matched to metric functions by index.
                    assert isinstance(value, list)
                    i: int
                    mtr: Callable
                    log.update({
                        mtr.__name__: value[i]
                        for i, mtr in enumerate(self.metric_fns)
                    })

                elif key == "val_metrics":
                    assert isinstance(value, list)
                    log.update({
                        "val_" + mtr.__name__: value[i]
                        for i, mtr in enumerate(self.metric_fns)
                    })

                else:
                    assert isinstance(value, float)
                    log[key] = value

            # Print logged info to stdout.
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # Evaluate model performance in accordance with the configured
            # metric, save best checkpoint as model_best.
            best: bool = False
            if self.mnt_mode != "off":
                try:
                    # Check whether model performance improved or not,
                    # according to the specified metric (mnt_metric).
                    improved: bool = (
                        self.mnt_mode == "min"
                        and log[self.mnt_metric] <= self.mnt_best) or (
                            self.mnt_mode == "max"
                            and log[self.mnt_metric] >= self.mnt_best)

                except KeyError:
                    # The configured metric was never logged — disable
                    # monitoring for the rest of the run rather than crash.
                    self.logger.warning(
                        "Warning: Metric '{}' is not found.\n".format(
                            self.mnt_metric) +
                        "Model performance monitoring is disabled.")
                    self.mnt_mode = "off"
                    improved = False
                    not_improved_count = 0

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn't improve for {} ".format(
                            self.early_stop) + "epochs.\nTraining stops.")
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _prepare_device(self,
                        n_gpu_use: int) -> Tuple[torch.device, List[int]]:
        """
        Setup GPU device if available, move model into configured device.

        Parameters
        ----------
        n_gpu_use : int
            The number of GPUs to use.

        Returns
        -------
        tuple
            A tuple of the device in use and a list of device IDs.
        """
        n_gpu: int = torch.cuda.device_count()
        # Degrade gracefully when the configuration asks for more GPUs than
        # the machine actually has.
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine, " +
                "training will be performed on CPU.")
            n_gpu_use = 0

        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU's configured to use is {}, ".
                format(n_gpu_use) +
                "but only {} are available on this machine.".format(n_gpu))
            n_gpu_use = n_gpu

        device: torch.device = torch.device(
            "cuda:0" if n_gpu_use > 0 else "cpu")
        list_ids: List[int] = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch: int, save_best: bool = False) -> None:
        """
        Saving current state as a checkpoint.

        Parameters
        ----------
        epoch : int
            The current epoch.

        save_best : bool, optional
            If True, the saved checkpoint is also written to "model_best.pth"
            (default is False).
        """
        arch: str = type(self.model).__name__
        state: Dict[str, Any] = {
            "arch": arch,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "monitor_best": self.mnt_best,
            "config": self.config,
        }
        filename: str = str(self.checkpoint_dir /
                            "checkpoint-epoch{}.pth".format(epoch))
        torch.save(state, filename)

        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path: str = str(self.checkpoint_dir / "model_best.pth")
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path: Union[Path, str]) -> None:
        """
        Resume from saved checkpoints.

        Parameters
        ----------
        resume_path : pathlib.Path or str
            File path of the checkpoint to resume from.
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint: dict = torch.load(resume_path)

        self.start_epoch = checkpoint["epoch"] + 1
        self.mnt_best = checkpoint["monitor_best"]

        # Load architecture params from checkpoint.
        if checkpoint["config"]["arch"] != self.config["arch"]:
            self.logger.warning(
                "Warning: Architecture configuration given in config file is different from that of"
                +
                " checkpoint. This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint["state_dict"])

        # Load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint["config"]["optimizer"]["type"] != self.config[
                "optimizer"]["type"]:
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint."
                + " Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
def train_model(train_dataset, train_num_each, val_dataset, val_num_each):
    """
    Train a ``res34_tcn`` model on sliding windows of frames and save the
    weights with the best validation accuracy.

    NOTE(review): this function depends on many module-level globals assumed
    to be defined elsewhere in the file (``sequence_length``, ``srate``,
    ``num_gpu``, ``epochs``, ``train_batch_size``, ``val_batch_size``,
    ``workers``, ``use_gpu``, ``crop_type``, ``learning_rate``,
    ``optimizer_choice``, ``sgd_adjust_lr``, ``exp_lr_scheduler``,
    ``multi_optim``, ``use_flip``) — verify they exist before calling.

    Parameters
    ----------
    train_dataset, val_dataset
        Datasets yielding ``(inputs, labels_phase, kdata)`` triples.
    train_num_each, val_num_each
        Per-video frame counts used to compute valid sequence start indices.
    """
    num_train = len(train_dataset)
    num_val = len(val_dataset)

    # Indices where a full sequence of `sequence_length` frames can start.
    train_useful_start_idx = get_useful_start_idx(sequence_length, train_num_each)
    #print('train_useful_start_idx ',train_useful_start_idx )
    val_useful_start_idx = get_useful_start_idx(sequence_length, val_num_each)
    #print('test_useful_start_idx ', val_useful_start_idx)

    # Truncate so the number of sequences divides evenly across GPUs.
    num_train_we_use = len(train_useful_start_idx) // num_gpu * num_gpu
    # print('num_train_we_use',num_train_we_use) #92166
    num_val_we_use = len(val_useful_start_idx) // num_gpu * num_gpu
    # print('num_val_we_use', num_val_we_use)
    # num_train_we_use = 8000
    # num_val_we_use = 800

    train_we_use_start_idx = train_useful_start_idx[0:num_train_we_use]  # start positions of the training sequences
    val_we_use_start_idx = val_useful_start_idx[0:num_val_we_use]

    # Fixed seed so the initial ordering is reproducible.
    np.random.seed(0)
    np.random.shuffle(train_we_use_start_idx)  # randomly reorder the sequence start positions
    train_idx = []
    for i in range(num_train_we_use):  # number of training sequences
        for j in range(sequence_length):
            # Expand each start position into `sequence_length` frame indices,
            # `srate` frames apart; each frame is one sample.
            train_idx.append(train_we_use_start_idx[i] + j * srate)
    # print('train_idx',train_idx)

    val_idx = []
    for i in range(num_val_we_use):
        for j in range(sequence_length):
            val_idx.append(val_we_use_start_idx[i] + j * srate)
    # print('val_idx',val_idx)

    num_train_all = float(len(train_idx))
    num_val_all = float(len(val_idx))
    print('num of train dataset: {:6d}'.format(num_train))
    print('num train start idx : {:6d}'.format(len(train_useful_start_idx)))
    print('last idx train start: {:6d}'.format(train_useful_start_idx[-1]))
    print('num of train we use : {:6d}'.format(num_train_we_use))
    print('num of all train use: {:6d}'.format(int(num_train_all)))
    print('num of valid dataset: {:6d}'.format(num_val))
    print('num valid start idx : {:6d}'.format(len(val_useful_start_idx)))
    print('last idx valid start: {:6d}'.format(val_useful_start_idx[-1]))
    print('num of valid we use : {:6d}'.format(num_val_we_use))
    print('num of all valid use: {:6d}'.format(int(num_val_all)))

    # The validation loader is built once; the training loader is rebuilt
    # every epoch with a fresh shuffle (see the epoch loop below).
    val_loader = DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        # sampler=val_idx,
        sampler=SeqSampler(val_dataset, val_idx),
        num_workers=workers,
        pin_memory=False
    )
    model = res34_tcn()
    if use_gpu:
        model = model.cuda()

    model = DataParallel(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # model.parameters() and model.state_dict() both expose network parameters:
    # the former is typically used to initialize optimizers, the latter to
    # save/load models.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_accuracy = 0.0
    correspond_train_acc = 0.0

    # Per-epoch record: [train_acc, train_loss, val_acc, val_loss].
    record_np = np.zeros([epochs, 4])

    for epoch in range(epochs):
        # Re-shuffle the training sequence starts each epoch (seeded by epoch
        # number for reproducibility).
        np.random.seed(epoch)
        np.random.shuffle(train_we_use_start_idx)  # randomly reorder the sequence start positions
        train_idx = []
        for i in range(num_train_we_use):
            for j in range(sequence_length):
                train_idx.append(train_we_use_start_idx[i] + j * srate)

        train_loader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            sampler=SeqSampler(train_dataset, train_idx),
            num_workers=workers,
            pin_memory=False
        )

        model.train()
        train_loss = 0.0
        train_corrects = 0
        train_start_time = time.time()
        num = 0
        train_num = 0
        for data in train_loader:
            num = num + 1
            # inputs, labels_phase = data
            inputs, labels_phase, kdata = data
            if use_gpu:
                # NOTE(review): Variable is a legacy (pre-0.4) PyTorch wrapper;
                # tensors are autograd-enabled directly in modern versions.
                inputs = Variable(inputs.cuda())
                labels = Variable(labels_phase.cuda())
                kdatas = Variable(kdata.cuda())
            else:
                inputs = Variable(inputs)
                labels = Variable(labels_phase)
                kdatas = Variable(kdata)
            optimizer.zero_grad()  # reset gradients to zero before backprop
            # outputs = model.forward(inputs)  # forward pass
            outputs = model.forward(inputs, kdatas)
            #outputs = F.softmax(outputs, dim=-1)
            # .data detaches the underlying tensor; torch.max(..., -1) returns
            # (max values, argmax indices) along the last dimension.
            _, preds = torch.max(outputs.data, -1)
            #_, yp = torch.max(y.data, 1)
            #print(yp)
            # print(yp.shape)
            print(num)
            print(preds)
            print(labels)


            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.data
            train_corrects += torch.sum(preds == labels.data)
            train_num += labels.shape[0]
            print(train_corrects.cpu().numpy() / train_num)
            # NOTE(review): this saves a snapshot on EVERY batch once running
            # accuracy exceeds 0.75 — likely a debugging leftover.
            if train_corrects.cpu().numpy() / train_num > 0.75:
                # state_dict() saves only the network parameters (fast, low memory)
                torch.save(copy.deepcopy(model.state_dict()), 'test.pth')

        train_elapsed_time = time.time() - train_start_time

        #train_accuracy1 = train_corrects1.cpu().numpy() / train_num
        train_accuracy = train_corrects.cpu().numpy() / train_num
        train_average_loss = train_loss / train_num

        # begin eval
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        val_num = 0
        val_start_time = time.time()
        for data in val_loader:
            inputs, labels_phase, kdata = data
            #inputs, labels_phase = data
            #labels_phase = labels_phase[(sequence_length - 1)::sequence_length]
            #kdata = kdata[(sequence_length - 1)::sequence_length]
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels = Variable(labels_phase.cuda())
                kdatas = Variable(kdata.cuda())
            else:
                inputs = Variable(inputs)
                labels = Variable(labels_phase)
                kdatas = Variable(kdata)

            # crop_type 5/10 presumably correspond to five-/ten-crop test-time
            # augmentation: fold crops into the batch, then average the
            # predictions over the crop dimension — TODO confirm.
            if crop_type == 0 or crop_type == 1:
                #outputs = model.forward(inputs)
                outputs = model.forward(inputs, kdatas)
            elif crop_type == 5:
                inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
                inputs = inputs.view(-1, 3, 224, 224)
                outputs = model.forward(inputs, kdatas)
                # outputs = model.forward(inputs)
                outputs = outputs.view(5, -1, 3)
                outputs = torch.mean(outputs, 0)
            elif crop_type == 10:
                inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
                inputs = inputs.view(-1, 3, 224, 224)
                outputs = model.forward(inputs, kdatas)
                #outputs = model.forward(inputs)
                outputs = outputs.view(10, -1, 3)
                outputs = torch.mean(outputs, 0)

            #outputs = outputs[sequence_length - 1::sequence_length]

            _, preds = torch.max(outputs.data, -1)
            #_, yp = torch.max(y.data, 1)
            # NOTE(review): `num` here is the leftover train-batch counter,
            # not a validation counter.
            print(num)
            print(preds)
            print(labels)


            loss = criterion(outputs, labels)
            #loss = 0.05 * loss1 + 0.15 * loss2 + 0.3 * loss3 + 0.5 * loss4
            #loss = 0.05 * loss1 + 0.1 * loss2 + 0.25 * loss3 + 0.6 * loss4
            val_loss += loss.data
            val_corrects += torch.sum(preds == labels.data)
            val_num += labels.shape[0]
        val_elapsed_time = time.time() - val_start_time
        val_accuracy = val_corrects.cpu().numpy() / val_num
        val_average_loss = val_loss / val_num
        print('epoch: {:4d}'
              ' train in: {:2.0f}m{:2.0f}s'
              ' train loss: {:4.4f}'
              ' train accu: {:.4f}'
              ' valid in: {:2.0f}m{:2.0f}s'
              ' valid loss: {:4.4f}'
              ' valid accu: {:.4f}'
              .format(epoch,
                      train_elapsed_time // 60,
                      train_elapsed_time % 60,
                      train_average_loss,
                      train_accuracy,
                      val_elapsed_time // 60,
                      val_elapsed_time % 60,
                      val_average_loss,
                      val_accuracy))

        # NOTE(review): `exp_lr_scheduler` is not defined in this function —
        # assumed to be a module-level global; verify before use.
        if optimizer_choice == 0:
            if sgd_adjust_lr == 0:
                exp_lr_scheduler.step()
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler.step(val_average_loss)

        # Keep the weights with the best validation accuracy; ties are broken
        # by the higher corresponding training accuracy.
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            correspond_train_acc = train_accuracy
            best_model_wts = copy.deepcopy(model.state_dict())
        if val_accuracy == best_val_accuracy:
            if train_accuracy > correspond_train_acc:
                correspond_train_acc = train_accuracy
                best_model_wts = copy.deepcopy(model.state_dict())


        record_np[epoch, 0] = train_accuracy
        record_np[epoch, 1] = train_average_loss
        record_np[epoch, 2] = val_accuracy
        record_np[epoch, 3] = val_average_loss
        # NOTE(review): writes one .npy per epoch, each containing the full
        # (partially-filled) record array.
        np.save(str(epoch) + '.npy', record_np)

    print('best accuracy: {:.4f} cor train accu: {:.4f}'.format(best_val_accuracy, correspond_train_acc))

    # Encode accuracies (x10000, as ints) into the output filenames.
    save_val = int("{:4.0f}".format(best_val_accuracy * 10000))
    save_train = int("{:4.0f}".format(correspond_train_acc * 10000))
    model_name = "tcn" \
                 + "_epoch_" + str(epochs) \
                 + "_length_" + str(sequence_length) \
                 + "_opt_" + str(optimizer_choice) \
                 + "_mulopt_" + str(multi_optim) \
                 + "_flip_" + str(use_flip) \
                 + "_crop_" + str(crop_type) \
                 + "_batch_" + str(train_batch_size) \
                 + "_train_" + str(save_train) \
                 + "_val_" + str(save_val) \
                 + ".pth"

    torch.save(best_model_wts, model_name)

    record_name = "tcn" \
                  + "_epoch_" + str(epochs) \
                  + "_length_" + str(sequence_length) \
                  + "_opt_" + str(optimizer_choice) \
                  + "_mulopt_" + str(multi_optim) \
                  + "_flip_" + str(use_flip) \
                  + "_crop_" + str(crop_type) \
                  + "_batch_" + str(train_batch_size) \
                  + "_train_" + str(save_train) \
                  + "_val_" + str(save_val) \
                  + ".npy"
    np.save(record_name, record_np)
                t.randn(CONFIG["BATCH_SIZE"], CONFIG["NOISE_DIM"], 1, 1))
            fake_img = netG(noises).detach()  # 根据噪声生成假图
            output = netD(fake_img)
            error_d_fake = criterion(output, fake_labels)
            error_d_fake.backward()
            optimizer_discriminator.step()

            error_d = error_d_fake + error_d_real

        if ii % 1 == 0:
            # 训练生成器
            netG.zero_grad()
            noises.data.copy_(
                t.randn(CONFIG["BATCH_SIZE"], CONFIG["NOISE_DIM"], 1, 1))
            fake_img = netG(noises)
            output = netD(fake_img)
            error_g = criterion(output, true_labels)
            error_g.backward()
            optimizer_generator.step()

        proBar.show(epoch, error_d.item(), error_g.item())

    # 保存模型、图片
    fix_fake_imgs = netG(fix_noises)
    tv.utils.save_image(fix_fake_imgs.data[:64],
                        'outputs/Pytorch_AnimateFace_%03d.png' % epoch,
                        normalize=True,
                        range=(-1, 1))

t.save(netG.state_dict(), "outputs/DCGAN_AnimateFace_Pytorch_Generator.pth")