Пример #1
0
def train(args, train_loader, valid_loader, model):
    """Run the full training loop, checkpointing periodically and at the end.

    Each epoch in ``[args.start_epoch, args.epochs]`` (inclusive) runs one
    training pass (``train_one_epoch``) and one validation pass
    (``valid_one_epoch``). Every ``args.save_freq`` epochs a checkpoint
    ``ckpt_epoch_<epoch>.pth`` is written. After the loop a final
    ``current.pth`` checkpoint and the training history (``history.npy``)
    are written.

    All files go under ``model/WRN_<depth>_<widen_factor>/``.

    Args:
        args: Namespace read for ``depth``, ``widen_factor``, ``epochs``,
            ``start_epoch``, ``lr`` and ``save_freq``. It is stored verbatim
            in every checkpoint under the ``"opt"`` key.
        train_loader: Loader forwarded to ``train_one_epoch``.
        valid_loader: Loader forwarded to ``valid_one_epoch``.
        model: Module whose ``state_dict()`` is checkpointed.
    """
    model_folder = os.path.join(
        "model/", "WRN_{}_{}/".format(args.depth, args.widen_factor))
    os.makedirs(model_folder, exist_ok=True)
    # Passed to train_one_epoch, which presumably appends to these lists.
    history = {"acc": [], "loss": []}
    elapsed_time = 0

    print("Start training!")
    print("| Training Epochs = {}".format(args.epochs))
    print("| Initial Learning Rate = {}".format(args.lr))

    # Last completed epoch. Initialised so the final checkpoint below never
    # hits an unbound name when the epoch range is empty.
    epoch = args.start_epoch - 1
    for epoch in range(args.start_epoch, args.epochs + 1):
        start_time = time.time()
        train_one_epoch(args, train_loader, model, epoch, history)
        valid_one_epoch(args, valid_loader, model, epoch)
        elapsed_time += time.time() - start_time
        print("| Elapsed time : %d:%02d:%02d" % get_hms(elapsed_time))

        if epoch % args.save_freq == 0:
            save_file = os.path.join(model_folder,
                                     "ckpt_epoch_{}.pth".format(epoch))
            print("==> Saving model at {}...".format(save_file))
            state = {
                "model": model.state_dict(),
                "epoch": epoch,
                "opt": args,
            }
            torch.save(state, save_file)
            del state

    print("=> Finish training")
    # Resolve the final path *before* announcing it. Previously the message
    # showed the last periodic checkpoint, or raised UnboundLocalError when
    # no periodic checkpoint had been written.
    save_file = os.path.join(model_folder, "current.pth")
    print("==> Saving model at {}...".format(save_file))
    state = {
        "model": model.state_dict(),
        "epoch": epoch,
        "opt": args,
    }
    torch.save(state, save_file)
    del state
    # The old str.format() on "history.npy" had no placeholders. Build the
    # path explicitly instead.
    np.save(os.path.join(model_folder, "history.npy"), history)
    torch.cuda.empty_cache()
Пример #2
0
 def update_query(self, q_str, warning=False):
     """Timestamp ``q_str``, show it as the title and append it to the stored entries.

     The stamped message is shown with ``utils.set_title``. It is then
     appended to the tuple held in this object's string variable, which is
     read with ``get_string_var`` and written with ``set_string_var``.
     Falsy entries are dropped before the tuple is stored back.

     Args:
         q_str: Message to record.
         warning: Accepted for API compatibility but ignored. It appears to
             have driven a log colour in now-removed code; the former
             ``if warning: ""`` branch was a no-op statement.
     """
     q_str = utils.get_hms() + " " + q_str
     utils.set_title(q_str)

     entry = (q_str,)
     var_str = self.get_string_var()
     if utils.var_is_empty(var_str):
         merged = entry
     else:
         previous = utils.var_to_list(var_str)
         merged = utils.append_tup(tuple(previous), entry) if len(previous) else entry

     # Drop falsy entries (e.g. blank strings) before storing the tuple back.
     self.set_string_var(tuple(item for item in merged if item))
Пример #3
0
 def set_title2(self, title, warning=False):
     """Show a timestamped ``title``, record it, and echo it to stdout.

     The title shown via ``utils.set_title`` is ``<titlePrefix>-<hms> <title>``.
     The stamped text (without the prefix) is forwarded to
     ``self.update_query`` together with ``warning``.
     """
     stamped = " ".join((utils.get_hms(), title))
     utils.set_title("-".join((self.titlePrefix, stamped)))
     # NOTE(review): if update_query also prepends a timestamp, the recorded
     # entry carries two of them. Confirm whether that is intended.
     self.update_query(stamped, warning)
     print(stamped)
Пример #4
0
            # Validate
            val_results = trn._test()
            val_writer.add_scalar('loss', val_results['mean_loss'], epoch)
            val_writer.add_scalar('acc', val_results['mean_accuracy'], epoch)
            val_writer.add_scalar('acc5', val_results['acc5'], epoch)
            acc = val_results['mean_accuracy']
            if acc > best_acc:
                print('| Saving Best model...\t\t\tTop1 = {:.2f}%'.format(acc))
                trn._save(outdir, 'model_best.pth')
                best_acc = acc

            trn._save(outdir, name='model_last.pth')
            epoch_time = time.time() - start_time
            elapsed_time += epoch_time
            print('| Elapsed time : %d:%02d:%02d\t Epoch time: %.1fs' %
                  (get_hms(elapsed_time) + (epoch_time, )))
            # Update the scheduler
            trn.step_lr()

        save_acc(outdir, best_acc, acc)
    # We are using a scheduler
    else:
        # Create the training object
        args.verbose = False
        import ray
        from ray import tune
        from ray.tune.schedulers import AsyncHyperBandScheduler
        ray.init()
        exp_name = args.outdir
        outdir = os.path.join(os.environ['HOME'], 'ray_results', exp_name)
        if not os.path.exists(outdir):
Пример #5
0
    def train(self, epoch_to_restore=0):
        """Train the generator ``g`` indefinitely, logging and checkpointing.

        Each round of the outer ``while True`` loop does the following:

        - Trains for ``self.nb_epochs_to_save`` epochs over the training set
          (L1 loss, Adam).
        - Without gradients, evaluates the L1 loss on at most the first 32
          batches of the training and test loaders.
        - Logs both losses, plus a grid of up to 16 test generations (4 per
          row), to the two summary writers.
        - Saves ``epoch_<epoch>.pth`` in ``self.dir_models``.

        The loop has no normal exit. It ends only when an exception (e.g. a
        keyboard interrupt) propagates, at which point the ``finally`` block
        closes both writers.

        Args:
            epoch_to_restore: 0 to start from scratch, which creates the
                output directories and applies ``weights_init``. Otherwise,
                the epoch whose checkpoint ``epoch_<n>.pth`` is loaded from
                ``self.dir_models``; epoch counting resumes from it.
        """
        if epoch_to_restore == 0:
            self.make_dirs()

        g = Generator(self.nb_channels_first_layer, self.dim)

        if epoch_to_restore > 0:
            filename_model = self.dir_models / 'epoch_{}.pth'.format(epoch_to_restore)
            g.load_state_dict(torch.load(filename_model))
        else:
            g.apply(weights_init)

        g.cuda()
        g.train()

        dataset_train = EmbeddingsImagesDataset(self.dir_z_train, self.dir_x_train)
        dataloader_train = DataLoader(dataset_train, self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
        dataset_test = EmbeddingsImagesDataset(self.dir_z_test, self.dir_x_test)
        # NOTE(review): the test loader also shuffles. Each evaluation round
        # therefore sees a different random subset of (at most 32) test batches.
        dataloader_test = DataLoader(dataset_test, self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

        criterion = torch.nn.L1Loss()

        optimizer = optim.Adam(g.parameters())
        writer_train = SummaryWriter(str(self.dir_logs_train))
        writer_test = SummaryWriter(str(self.dir_logs_test))

        try:
            epoch = epoch_to_restore
            # Runs until an exception escapes; the writers are closed in `finally`.
            while True:
                start_time = time.time()

                g.train()
                for _ in range(self.nb_epochs_to_save):
                    epoch += 1

                    for idx_batch, current_batch in enumerate(dataloader_train):
                        g.zero_grad()
                        # NOTE(review): Variable is a no-op wrapper in PyTorch >= 0.4.
                        x = Variable(current_batch['x']).float().cuda()
                        z = Variable(current_batch['z']).float().cuda()
                        g_z = g.forward(z)

                        loss = criterion(g_z, x)
                        loss.backward()
                        optimizer.step()

                g.eval()
                with torch.no_grad():
                    # Training-set L1 loss over at most the first 32 batches.
                    train_l1_loss = AverageMeter()
                    for idx_batch, current_batch in enumerate(dataloader_train):
                        if idx_batch == 32:
                            break
                        x = current_batch['x'].float().cuda()
                        z = current_batch['z'].float().cuda()
                        g_z = g.forward(z)
                        loss = criterion(g_z, x)
                        # NOTE(review): this stores the loss tensor itself;
                        # loss.item() may be preferable depending on AverageMeter.
                        train_l1_loss.update(loss)

                    writer_train.add_scalar('l1_loss', train_l1_loss.avg, epoch)

                    # Test-set L1 loss over at most the first 32 batches.
                    test_l1_loss = AverageMeter()
                    for idx_batch, current_batch in enumerate(dataloader_test):
                        if idx_batch == 32:
                            break
                        x = current_batch['x'].float().cuda()
                        z = current_batch['z'].float().cuda()
                        g_z = g.forward(z)
                        loss = criterion(g_z, x)
                        test_l1_loss.update(loss)

                    writer_test.add_scalar('l1_loss', test_l1_loss.avg, epoch)
                    # g_z comes from the last evaluated test batch. If the test
                    # loader yields nothing, it is left over from the training
                    # evaluation above.
                    images = make_grid(g_z.data[:16], nrow=4, normalize=True)
                    writer_test.add_image('generations', images, epoch)

                # epoch advances by exactly nb_epochs_to_save per round, so this
                # holds whenever epoch_to_restore is a multiple of it.
                if epoch % self.nb_epochs_to_save == 0:
                    filename = os.path.join(self.dir_models, 'epoch_{}.pth'.format(epoch))
                    torch.save(g.state_dict(), filename)

                end_time = time.time()
                # NOTE(review): get_hms's result is formatted with "{}". Confirm
                # it returns a display string rather than an (h, m, s) tuple.
                print("[*] Finished epoch {} in {}".format(epoch, get_hms(end_time - start_time)))

        finally:
            print('[*] Closing Writer.')
            writer_train.close()
            writer_test.close()
Пример #6
0
def train_fun(epoch, net):
    """Train ``net`` for one epoch, optionally evaluate it, and checkpoint it.

    Relies on module-level names: ``args``, ``optimizer``, ``trainloader``,
    ``testloader``, ``device``, ``start_time``, ``tqdm`` and ``get_hms``.

    Steps:
      1. Set the learning rate with a two-step decay schedule
         (``args.decay_epoch1`` / ``args.decay_epoch2``, factor
         ``args.decay_rate``).
      2. Run one pass over ``trainloader``. ``net(inputs, targets)`` returns
         ``(outputs_adv, outputs_nat, loss, nat_loss, adv_loss)``, and the
         mean of ``loss`` is back-propagated. Every ``args.log_step`` batches,
         the batch's natural and adversarial accuracies are shown on the
         progress bar.
      3. Every ``args.save_epochs`` epochs, measure natural accuracy on
         ``testloader`` and save ``<save_name>-<epoch>.t7``.
      4. For ``epoch >= 0``, save ``<save_name>-latest.t7`` with the net and
         optimizer state.

    Args:
        epoch: Current epoch index. Used for the LR schedule, logging and
            checkpoint names.
        net: Model wrapper with the call contract described in step 2. For
            evaluation it is called as ``net(inputs, targets, attack=False)``
            and must return a 2-tuple whose first item is the logits.
    """
    print('\n Epoch: %d' % epoch)
    net.train()

    # Step-wise learning-rate decay.
    if epoch < args.decay_epoch1:
        lr = args.lr
    elif epoch < args.decay_epoch2:
        lr = args.lr * args.decay_rate
    else:
        lr = args.lr * args.decay_rate * args.decay_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    def count_correct(outputs, targets):
        """Number of rows whose arg-max in ``outputs`` equals ``targets``."""
        _, predicted = outputs.max(1)
        return predicted.eq(targets).sum().item()

    def get_acc(outputs, targets):
        """Fraction of correctly classified rows in one batch."""
        return 1.0 * count_correct(outputs, targets) / targets.size(0)

    iterator = tqdm(trainloader, ncols=0, leave=False)
    for batch_idx, (inputs, targets) in enumerate(iterator):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        # forward
        outputs_adv, outputs_nat, loss, _nat_loss, _adv_loss = net(
            inputs.detach(), targets)

        # Clear again: the forward pass may run backward passes while crafting
        # adversarial examples, which could leave gradients on the parameters.
        optimizer.zero_grad()
        # ``loss`` may be per-sample or per-device; reduce before backward.
        total_loss = loss.mean()
        total_loss.backward()

        optimizer.step()

        if batch_idx % args.log_step == 0:
            adv_acc = get_acc(outputs_adv, targets)
            nat_acc = get_acc(outputs_nat, targets)

            duration = time.time() - start_time
            h, m, s = get_hms(duration)
            print('\r')
            # Log the reduced scalar; "%.4f" fails on a multi-element tensor.
            inform = "| Step %3d, lr %.4f, time %d:%02d:%02d, loss %.4f, nat acc %.2f,adv acc %.2f" % (
                batch_idx, lr, h, m, s, total_loss.item(), 100 * nat_acc, 100 * adv_acc)
            iterator.set_description(str(inform))

    # exist_ok avoids the check-then-create race. save_point is still built by
    # concatenation, so args.model_dir is expected to end with a separator.
    os.makedirs(args.model_dir, exist_ok=True)
    save_point = args.model_dir + args.dataset + os.sep
    os.makedirs(save_point, exist_ok=True)

    if epoch % args.save_epochs == 0:
        net.eval()
        correct = 0
        total = 0
        # Evaluation needs no autograd graph. Counting samples (rather than
        # averaging per-batch accuracies) stays exact when the last batch is
        # smaller.
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs_nat, _ = net(inputs, targets, attack=False)
                correct += count_correct(outputs_nat, targets)
                total += targets.size(0)

        test_acc = 100.0 * correct / total if total else 0.0
        print(f'| Test acc:{test_acc:.4f}')
        print('| Saving...')
        state = {
            'net': net.state_dict(),
        }
        f_path = save_point + args.save_name + f'-{epoch}.t7'
        print(f_path)
        torch.save(state, f_path)

    if epoch >= 0:
        print(f'| Saving {args.save_name} latest @ {epoch}...\r')
        state = {
            'net': net.state_dict(),
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
        }
        f_path = save_point + args.save_name + '-latest.t7'
        torch.save(state, f_path)
Пример #7
0
    def process(self, dc, add_str=''):
        """Transcode every file in ``dc["list"]`` according to the options in ``dc``.

        Depending on the ``*_select`` flags, each input file may produce:

        - 640x360 and/or 1280x720 encodes in mp4, flv and/or m3u8;
        - the TS outputs of ``encode_teleconm_ts`` (format4);
        - the TS outputs of ``encode_china_mobile_ts`` (format5).

        Progress is reported through ``self.start_btn.update_query``, and a
        line per input file is appended to ``<output_dir>/log.txt``.

        After the loop, ``utils.open_file`` is called on the last produced
        output and ``self.lock_btn(False)`` is called. Finally,
        ``self.item_shutdown.shutdown()`` is invoked.

        Args:
            dc: Option dict. Reads ``"list"`` (input paths), ``"output_dir"``,
                ``"number_file"``, and the string-encoded booleans
                ``size1_select``, ``size2_select``, ``tune2_select``,
                ``format1_select`` through ``format5_select``,
                ``keep_parent_select``, ``keep_ratio_select``,
                ``double_fix_select`` and ``number_select``.
            add_str: Unused; kept for interface compatibility.
        """
        set_title = self.start_btn.update_query
        write_log = utils.write_log

        input_list = dc["list"]
        output_dir_raw = dc["output_dir"] + os.sep
        temp_dir = output_dir_raw + 'tempDir' + os.sep
        utils.make_dir(temp_dir)
        utils.hide_file(temp_dir)

        log_txt = output_dir_raw + "log.txt"
        log_txt_str = log_txt.replace('\\', '/')

        s2bool = utils.str_to_bool
        size1_select = s2bool(dc["size1_select"])
        size2_select = s2bool(dc["size2_select"])
        tune2_select = s2bool(dc["tune2_select"])
        format1_select = s2bool(dc["format1_select"])
        format2_select = s2bool(dc["format2_select"])
        format3_select = s2bool(dc["format3_select"])
        format4_select = s2bool(dc["format4_select"])
        format5_select = s2bool(dc["format5_select"])
        keep_parent_select = s2bool(dc["keep_parent_select"])
        keep_ratio_select = s2bool(dc["keep_ratio_select"])
        double_fix_select = s2bool(dc["double_fix_select"])
        number_select = s2bool(dc["number_select"])
        number_file = dc["number_file"]
        if not number_select:
            number_file = ''

        # x264 "-tune" preset, identical for every file.
        tune_str = 'film' if tune2_select else 'animation'

        final_mp4 = ""
        output_file = ""

        total = len(input_list)
        for count, input_file in enumerate(input_list, start=1):
            p = Path(input_file)

            # Keep the parent directory name in the output layout, except for
            # files directly under a filesystem root (their parent's name is
            # empty). The old check compared "<drive><sep>" with Path.name,
            # which never contains a separator, so it could never match.
            path_parent = p.parent.name
            if keep_parent_select and path_parent:
                output_dir = "{0}{1}{2}".format(output_dir_raw, path_parent,
                                                os.sep)
                utils.make_dir(output_dir)
            else:
                output_dir = output_dir_raw

            # Task info: tell the user where the log goes (first file only).
            if count < 2:
                set_title("本次转码将记录到日志:" + log_txt_str)

            msg_str = "  ({0}/{1})  {2}".format(count, total, '{}')
            set_title(msg_str.format(p.stem))

            # Write the full input path to the log.
            log_str = '{0}	{1}'.format(utils.get_hms(), msg_str.format(p))
            write_log(log_txt, log_str)

            # Shared encoder settings. Resolution/fps are set per size below.
            obj = ff.create_obj()
            obj.input_file = input_file
            obj.crf = 26
            obj.audio_bitrate = '96k'
            obj.other_param = '-ar 44100 -ac 2 -preset slower -threads 8'
            obj.tune = tune_str

            # 640x360 outputs.
            if size1_select:
                output_file = output_dir + p.stem + "_640.mp4"
                obj.size = '640x360'
                obj.fps = 24
                if format1_select:  # mp4
                    set_title(msg_str.format('  正在生成 640 视频'))
                    obj.output_file = output_file
                    obj.execute()
                if format2_select:  # flv
                    set_title(msg_str.format('  正在生成 640 视频(flv)'))
                    obj.output_file = Path(output_file).with_suffix('.flv')
                    obj.execute()
                if format3_select:  # m3u8
                    set_title(msg_str.format('  正在生成 640 视频(m3u8)'))
                    output_file = self.encode_m3u8(input_file, output_dir, True,
                                                   tune_str)

            # 1280x720 outputs.
            if size2_select:
                output_file = output_dir + p.stem + "_1280.mp4"
                obj.size = '1280x720'
                obj.fps = 24
                if format1_select:  # mp4
                    set_title(msg_str.format('  正在生成 1280 视频'))
                    obj.output_file = output_file
                    obj.execute()
                if format2_select:  # flv
                    set_title(msg_str.format('  正在生成 1280 视频(flv)'))
                    obj.output_file = Path(output_file).with_suffix('.flv')
                    obj.execute()
                if format3_select:  # m3u8
                    set_title(msg_str.format('  正在生成 1280 视频(m3u8)'))
                    output_file = self.encode_m3u8(input_file, output_dir, False,
                                                   tune_str)

            # NOTE(review): both TS messages below say 1280x720, but
            # '1920x1080' is what gets passed. Confirm which is intended.

            # China Telecom TS.
            if format4_select:
                ss = msg_str.format('  正在生成 电信 ts ')
                if count < 2:
                    ss += '分辨率统一为 1280x720'
                set_title(ss)
                output_file = self.encode_teleconm_ts(input_file, output_dir,
                                                      tune_str, number_select,
                                                      number_file, '1920x1080',
                                                      15, keep_ratio_select,
                                                      double_fix_select)

            # China Mobile TS.
            if format5_select:
                ss = msg_str.format('  正在生成 移动 ts ')
                if count < 2:
                    ss += '分辨率统一为 1280x720 (可以在输出目录的 "移动"文件夹下找到) '
                set_title(ss)
                output_file = self.encode_china_mobile_ts(
                    input_file, output_dir, tune_str, number_select,
                    number_file, '1920x1080', 15, keep_ratio_select,
                    double_fix_select)

            final_mp4 = output_file
            obj.destroy()

        set_title("操作结束!")
        set_title("")

        # Automatically open the directory (via utils.open_file on the last output).
        if final_mp4:
            utils.open_file(final_mp4, True)

        self.t1 = ""
        self.lock_btn(False)

        # Check for and execute shutdown.
        self.item_shutdown.shutdown()