Пример #1
0
class RL_TD3(RL):
    """TD3-style actor-critic trainer (twin Q networks + target networks).

    Sample collection runs in separate processes (``collect_samples``) that
    push trajectories through an ``mp.Queue``; the main process
    (``collect_samples_multithread``) drains the queue, updates the networks
    on ``device``, and uses a ``TrafficLight``/``Counter`` pair to signal the
    workers when a new round may start.
    """

    def __init__(self, env, hidden_layer=[64, 64]):
        # NOTE(review): mutable default argument; safe only if callers never
        # mutate the list, but worth confirming.
        super().__init__(env, hidden_layer=hidden_layer)
        self.env = env
        self.num_inputs = env.observation_space.shape[0]
        self.num_outputs = env.action_space.shape[0]
        self.hidden_layer = hidden_layer
        self.params = Params()
        # Actor and its polyak-averaged target copy (kept on CPU so worker
        # processes can share them; moved to `device` only during updates).
        self.actor = ActorNet(self.num_inputs, self.num_outputs,
                              self.hidden_layer).to("cpu")
        self.actor_target = ActorNet(self.num_inputs, self.num_outputs,
                                     self.hidden_layer).to("cpu")
        self.actor_target.load_state_dict(self.actor.state_dict())
        # Twin Q network (returns two Q estimates) and its target copy.
        self.q_function = QNet(self.num_inputs, self.num_outputs,
                               self.hidden_layer).to("cpu")
        self.q_function_target = QNet(self.num_inputs, self.num_outputs,
                                      self.hidden_layer).to("cpu")
        self.q_function_target.load_state_dict(self.q_function.state_dict())
        #self.actor_target = ActorNet(self.num_inputs, num_outputs, self.hidden_layer)
        #self.actor_target.load_state_dict(self.actor.state_dict())
        # Put all network parameters in shared memory so child processes see
        # the same tensors.
        self.q_function.share_memory()
        self.actor.share_memory()
        self.q_function_target.share_memory()
        self.actor_target.share_memory()
        # Running observation statistics shared across processes.
        self.shared_obs_stats = Shared_obs_stats(self.num_inputs)
        # NOTE(review): capacity is passed as a float (1e8); assumes
        # ReplayMemory accepts that — confirm.
        self.memory = ReplayMemory(1e8)
        self.test_mean = []
        self.test_std = []
        self.lr = 1e-4
        plt.show(block=False)
        self.test_list = []

        #for multiprocessing queue
        self.queue = mp.Queue()
        self.process = []
        self.traffic_light = TrafficLight()
        self.counter = Counter()

        self.off_policy_memory = ReplayMemory(1e8)

        #self.q_function_optimizer = optim.Adam(self.q_function.parameters(), lr=self.lr, weight_decay=0e-3)
        #self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.lr, weight_decay=0e-3)
        self.q_function_optimizer = RAdam(self.q_function.parameters(),
                                          lr=self.lr,
                                          weight_decay=0e-3)
        self.actor_optimizer = RAdam(self.actor.parameters(),
                                     lr=self.lr,
                                     weight_decay=0e-3)
        self.actor.train()
        self.q_function.train()

        # NOTE(review): attribute name is misspelled ("fucntion"); kept as-is
        # because renaming would change the public attribute surface.
        self.q_fucntion_scheduler = optim.lr_scheduler.ExponentialLR(
            self.q_function_optimizer, gamma=0.99)
        self.actor_scheduler = optim.lr_scheduler.ExponentialLR(
            self.actor_optimizer, gamma=0.99)

        self.off_policy_queue = mp.Queue()

        self.reward_scale = 1
        # Shared float used to normalize rewards in update_q_function;
        # capped at 50 when updated in collect_samples.
        self.max_reward = mp.Value("f", 50)
        #self.actor_target.share_memory()

    def run_test(self, num_test=1):
        """Run `num_test` deterministic episodes (using the policy mean) and
        record mean/std of the total rewards into the test statistics lists."""
        state = self.env.reset()  #_for_test()
        state = Variable(torch.Tensor(state).unsqueeze(0))
        ave_test_reward = 0

        total_rewards = []
        #actor_old = ActorNet(self.num_inputs, self.num_outputs, self.hidden_layer)
        #actor_old.load_state_dict(self.actor.state_dict())

        for i in range(num_test):
            total_reward = 0
            num_step = 0
            while True:
                state = self.shared_obs_stats.normalize(state).to("cpu")
                #print(self.actor.device)
                action, _, mean, log_std = self.actor.sample(state)
                #print("log std", log_std)
                print(mean.data)
                # Act with the policy mean (no exploration noise) for testing.
                env_action = mean.cpu().data.squeeze().numpy()
                state, reward, done, _ = self.env.step(env_action)

                total_reward += reward
                num_step += 1
                #done = done or (num_step > 400)

                if done:
                    state = self.env.reset()  #_for_test()
                    #print(self.env.position)
                    #print(self.env.time)
                    state = Variable(torch.Tensor(state).unsqueeze(0))
                    ave_test_reward += total_reward / num_test
                    total_rewards.append(total_reward)
                    break
                state = Variable(torch.Tensor(state).unsqueeze(0))

        # NOTE(review): statistics.stdev raises if num_test == 1 (needs at
        # least two data points) — confirm callers always pass num_test >= 2.
        reward_mean = statistics.mean(total_rewards)
        reward_std = statistics.stdev(total_rewards)
        self.test_mean.append(reward_mean)
        self.test_std.append(reward_std)
        self.test_list.append((reward_mean, reward_std))

    def plot_statistics(self):
        """Plot mean test reward with a +/- one-std band on self.fig."""
        ax = self.fig.add_subplot(111)
        low = []
        high = []
        index = []
        for i in range(len(self.test_mean)):
            low.append(self.test_mean[i] - self.test_std[i])
            high.append(self.test_mean[i] + self.test_std[i])
            index.append(i)
        #ax.set_xlim([0,1000])
        #ax.set_ylim([0,300])
        plt.xlabel('iterations')
        plt.ylabel('average rewards')
        ax.plot(self.test_mean, 'b')
        ax.fill_between(index, low, high, color='cyan')
        #ax.plot(map(sub, test_mean, test_std))
        self.fig.canvas.draw()

    def collect_samples(self, num_samples, noise=-2.0, random_seed=1):
        """Worker loop (run in a child process): collect up to `num_samples`
        transitions per round, push them on self.queue, then wait for the
        traffic light to switch before starting the next round.

        Never returns; intended as an mp.Process target.
        """
        torch.set_num_threads(1)
        #print(random_seed)
        #env.seed(random_seed+3)
        #random seed is used to make sure different thread generate different trajectories
        random.seed(random_seed)
        torch.manual_seed(random_seed + 1)
        np.random.seed(random_seed + 2)
        #torch.cuda.manual_seed_all(random_seed+3)
        start_state = self.env.reset()
        samples = 0
        done = False
        states = []
        next_states = []
        actions = []
        rewards = []
        q_values = []
        dones = []
        self.actor.set_noise(noise)
        state = start_state

        state = Variable(torch.Tensor(state).unsqueeze(0))
        total_reward = 0
        start = t.time()
        while True:
            #actor_old = ActorNet(self.num_inputs, self.num_outputs, self.hidden_layer)
            #actor_old.load_state_dict(self.actor.state_dict())
            #print("something")
            # Refresh the worker's policy from the latest checkpoint saved by
            # the learner (see save_actor in collect_samples_multithread).
            self.actor.load_state_dict(torch.load(self.model_path))
            signal_init = self.traffic_light.get()
            while samples < num_samples and not done:
                state = self.shared_obs_stats.normalize(state).to("cpu")
                states.append(state.cpu().data.numpy())
                # Exploration phase uses uniform random actions until the
                # learner flips traffic_light.explore off.
                if self.traffic_light.explore.value == False:  # and random.randint(0,90)%100 > 0:
                    action, _, mean, _ = self.actor.sample(state)
                    action.detach()
                    mean.detach()
                    #print(action)
                else:
                    # Uniform random action in [-1, 1] per dimension.
                    action = np.random.randint(
                        -100, 100,
                        size=(self.env.action_space.shape[0], )) * 1.0 / 100.0
                    #action = self.env.action_space.sample()
                    action = Variable(torch.Tensor(action).unsqueeze(0))
                actions.append(action.cpu().data.numpy())
                env_action = action.cpu().data.squeeze().numpy()

                state, reward, done, _ = self.env.step(env_action)
                # Track the largest observed reward (clipped to 50); used as a
                # normalizer in update_q_function.
                if reward > self.max_reward.value:
                    self.max_reward.value = min(reward, 50.0)
                #print(env_action)
                #print(samples, env_action, reward, state)

                #print(reward)
                total_reward += reward
                #print(samples, total_reward)
                rewards.append(
                    Variable(reward * torch.ones(1, 1)).data.numpy())
                state = Variable(torch.Tensor(state).unsqueeze(0))
                #print(state.shape)
                next_state = self.shared_obs_stats.normalize(state)
                next_states.append(next_state.cpu().data.numpy())
                # Stored as (1 - done): 0 for terminal transitions, so it can
                # be used directly as a bootstrap mask in the Bellman target.
                dones.append(
                    Variable((1 - done) * torch.ones(1, 1)).data.numpy())
                samples += 1
                #done = (done or samples > num_samples)

            self.queue.put([states, actions, next_states, rewards, dones])
            #print(self.actor.p_fcs[0].bias.data[0])
            self.counter.increment()
            #print("waiting sim time passed", t.time() - start)
            start = t.time()
            # Busy-wait until the learner switches the traffic light,
            # signalling that this round's data has been consumed.
            while self.traffic_light.get() == signal_init:
                pass
            start = t.time()
            state = self.env.reset()
            state = Variable(torch.Tensor(state).unsqueeze(0))
            samples = 0
            print(total_reward)
            total_reward = 0
            done = False
            states = []
            next_states = []
            actions = []
            rewards = []
            values = []
            q_values = []
            dones = []

    def collect_expert_samples(self,
                               num_samples,
                               filename,
                               noise=-2.0,
                               validation=False,
                               difficulty=[0, 0]):
        """Roll out a pre-trained expert policy (loaded from `filename`) in a
        fresh Walker3DStepper env and push qualifying trajectories into
        off_policy_memory (or validation_trajectory when validation=True).

        Only trajectories with score >= num_samples are kept.
        """
        import gym
        expert_env = gym.make("mocca_envs:Walker3DStepperEnv-v0")
        expert_env.set_difficulty(difficulty)
        start_state = expert_env.reset()
        samples = 0
        done = False
        states = []
        next_states = []
        actions = []
        rewards = []
        q_values = []
        dones = []
        model_expert = self.Net(self.num_inputs, self.num_outputs,
                                [256, 256, 256, 256, 256])

        model_expert.load_state_dict(torch.load(filename))
        policy_noise = noise * np.ones(self.num_outputs)
        model_expert.set_noise(policy_noise)

        state = start_state
        state = Variable(torch.Tensor(state).unsqueeze(0))
        total_reward = 0
        total_sample = 0
        #q_value = Variable(torch.zeros(1, 1))
        if validation:
            max_sample = 300
        else:
            max_sample = 10000
        while total_sample < max_sample:
            score = 0
            while samples < num_samples and not done:
                state = self.shared_obs_stats.normalize(state)

                states.append(state.data.numpy())
                # Record the expert's best (mean) action as the target action.
                mu = model_expert.sample_best_actions(state)
                actions.append(mu.data.numpy())
                eps = torch.randn(mu.size())
                # NOTE(review): both branches assign the same weight and
                # `eps`/`weight` are unused below — likely leftover code.
                if validation:
                    weight = 0.1
                else:
                    weight = 0.1
                # Step the env with a noisy sample from the expert policy.
                env_action = model_expert.sample_actions(state)
                env_action = env_action.data.squeeze().numpy()

                state, reward, done, _ = expert_env.step(env_action)
                dones.append(
                    Variable((1 - done) * torch.ones(1, 1)).data.numpy())
                rewards.append(
                    Variable(reward * torch.ones(1, 1)).data.numpy())
                state = Variable(torch.Tensor(state).unsqueeze(0))

                next_state = self.shared_obs_stats.normalize(state)
                next_states.append(next_state.data.numpy())

                samples += 1
                #total_sample += 1
                score += reward
            print("expert score", score)
            # state = self.shared_obs_stats.normalize(state)
            # v = model_expert.get_value(state)
            # if done:
            #     R = torch.zeros(1, 1)
            # else:
            #     R = v.data
            #     R = torch.ones(1, 1) * 100
            # R = Variable(R)
            # for i in reversed(range(len(rewards))):
            #     R = self.params.gamma * R + Variable(torch.from_numpy(rewards[i]))
            #     q_values.insert(0, R.data.numpy())

            # Keep only high-scoring trajectories.
            if not validation and score >= num_samples:
                self.off_policy_memory.push(
                    [states, actions, next_states, rewards, dones])
                total_sample += num_samples
            elif score >= num_samples:
                self.validation_trajectory.push(
                    [states, actions, next_states, rewards, dones])
            start_state = expert_env.reset()
            state = start_state
            state = Variable(torch.Tensor(state).unsqueeze(0))
            total_reward = 0
            samples = 0
            done = False
            states = []
            next_states = []
            actions = []
            rewards = []
            q_values = []

    def update_q_function(self, batch_size, num_epoch, update_actor=False):
        """Run `num_epoch` TD3-style critic updates from off_policy_memory;
        optionally also take a deterministic policy-gradient actor step.

        Rewards are normalized by the shared max_reward value. The Bellman
        target uses the min of the two target-Q heads (clipped double-Q).
        """
        for k in range(num_epoch):
            batch_states, batch_actions, batch_next_states, batch_rewards, batch_dones = self.off_policy_memory.sample(
                batch_size)
            # batch_states2, batch_actions2, batch_next_states2, batch_rewards2, batch_dones2 = self.memory.sample(self.num_threads)

            batch_states = Variable(torch.Tensor(batch_states)).to(device)
            batch_next_states = Variable(
                torch.Tensor(batch_next_states)).to(device)
            batch_actions = Variable(torch.Tensor(batch_actions)).to(device)
            batch_rewards = Variable(
                torch.Tensor(batch_rewards / self.max_reward.value)).to(device)
            batch_dones = Variable(torch.Tensor(batch_dones)).to(device)

            # batch_states2 = Variable(torch.Tensor(batch_states2))
            # batch_next_states2 = Variable(torch.Tensor(batch_next_states2))
            # batch_actions2 = Variable(torch.Tensor(batch_actions2))
            # batch_rewards2 = Variable(torch.Tensor(batch_rewards2 * self.reward_scale))
            # batch_dones2 = Variable(torch.Tensor(batch_dones2))

            # batch_states = torch.cat([batch_states, batch_states2], 0).to(device)
            # batch_next_states = torch.cat([batch_next_states, batch_next_states2], 0).to(device)
            # batch_actions = torch.cat([batch_actions, batch_actions2], 0).to(device)
            # batch_rewards = torch.cat([batch_rewards, batch_rewards2], 0).to(device)
            # batch_dones = torch.cat([batch_dones, batch_dones2], 0).to(device)

            #compute on policy actions for next state
            batch_next_state_action, batch_next_log_prob, batch_next_state_action_mean, _, = self.actor_target.sample_gpu(
                batch_next_states)
            #compute q value for these actions
            q_next_1_target, q_next_2_target = self.q_function_target(
                batch_next_states, batch_next_state_action)
            q = torch.min(q_next_1_target, q_next_2_target)

            #value functions estimate of the batch_states
            # batch_dones holds (1 - done), so terminal states drop the
            # bootstrap term automatically.
            value = batch_rewards + batch_dones * self.params.gamma * q

            #q value estimate
            q1, q2 = self.q_function(batch_states, batch_actions)
            #print(q1.shape, value.shape)
            q1_value_loss = F.mse_loss(q1, value)
            q2_value_loss = F.mse_loss(q2, value)
            q_value_loss = q1_value_loss + q2_value_loss
            #print(q_value_loss)

            self.q_function_optimizer.zero_grad()
            q_value_loss.backward()
            self.q_function_optimizer.step()

            if update_actor is False:
                continue

            # Actor step: maximize min-Q of the mean action, with an L2
            # action-magnitude penalty weighted by self.action_weight.
            mean_action, log_std = self.actor(batch_states)
            q1_new, q2_new = self.q_function(batch_states, mean_action)
            new_q_value = torch.min(q1_new,
                                    q2_new)  # - self.critic(batch_states)
            policy_loss = (-new_q_value).mean() + self.action_weight * (
                mean_action**2).mean()
            #print("policy_loss",  (-new_q_value).mean())
            #print("log_prob", log_prob.shape, new_q_value.shape)

            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            self.actor_optimizer.step()

    def update_q_target(self, tau):
        """Polyak-average the critic target toward the live critic."""
        for target_param, param in zip(self.q_function_target.parameters(),
                                       self.q_function.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) +
                                    param.data * tau)

    def update_actor_target(self, tau):
        """Polyak-average the actor target toward the live actor."""
        for target_param, param in zip(self.actor_target.parameters(),
                                       self.actor.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) +
                                    param.data * tau)

    def save_actor(self, filename):
        """Checkpoint the actor's state dict to `filename`."""
        torch.save(self.actor.state_dict(), filename)

    def collect_samples_multithread(self):
        """Main learner loop: spawn worker processes, alternate critic-only
        and critic+actor updates, periodically evaluate and checkpoint.

        Runs (practically) forever; workers synchronize via traffic_light.
        """
        import time
        self.num_samples = 0
        self.start = time.time()
        self.num_threads = 1
        self.action_weight = 0.01
        self.lr = 1e-4
        # Start in exploration mode (workers take random actions).
        self.traffic_light.explore.value = True
        self.time_passed = 0
        max_samples = 0
        seeds = [
            np.random.randint(0, 4294967296) for _ in range(self.num_threads)
        ]

        ts = [
            mp.Process(target=self.collect_samples,
                       args=(500, ),
                       kwargs={
                           'noise': -2.0,
                           'random_seed': seed
                       }) for seed in seeds
        ]
        for t in ts:
            t.start()

        for iter in range(1000000):
            # Drain one round of trajectories from every worker.
            while len(self.memory.memory) < max_samples:
                if self.counter.get() == self.num_threads:
                    for i in range(self.num_threads):
                        #if random.randint(0, 1) == 0:
                        self.memory.push(self.queue.get())
                        # else:
                        #     self.memory.push_half(self.queue.get())
                    self.counter.increment()
                if self.counter.get() == self.num_threads + 1:
                    break
            print(len(self.memory.memory))
            off_policy_memory_len = len(self.off_policy_memory.memory)
            #print(off_policy_memory_len)
            memory_len = len(self.memory.memory)
            #print(len(self.memory.memory))
            #self.update_critic(128, 640 * int(memory_len/3000))
            if off_policy_memory_len >= 128:
                # Switch workers from random exploration to policy actions
                # once enough data has been collected.
                if off_policy_memory_len > 10000:
                    self.traffic_light.explore.value = False
                else:
                    print("explore")
                # Move networks to the training device for the update phase.
                self.actor.to(device)
                self.q_function.to(device)
                self.actor_target.to(device)
                self.q_function_target.to(device)
                #for policy_update in range(len(self.memory.memory)):
                for policy_update in range(32):
                    # Alternate: critic-only burst, then a delayed actor step
                    # plus soft target updates (TD3 delayed policy updates).
                    if policy_update % 2 == 0:
                        self.update_q_function(128, 100)
                    else:
                        self.update_q_function(128, 1, update_actor=True)
                        self.update_q_target(0.005)
                        self.update_actor_target(0.005)
                # Back to CPU so the shared-memory workers can read them.
                self.actor.to("cpu")
                self.q_function.to("cpu")
                self.actor_target.to("cpu")
                self.q_function_target.to("cpu")

                #print(self.actor.p_fcs[0].bias.data[0])
                self.num_samples += memory_len
                self.save_actor(self.model_path)
                if iter % 10 == 0:
                    self.run_test(num_test=2)
                    self.plot_statistics()
                    print(self.num_samples, self.test_mean[-1])

            # Fold the fresh on-policy data into the off-policy buffer.
            self.off_policy_memory.memory = self.off_policy_memory.memory + self.memory.memory
            #if (math.isnan(len(self.memory.memory))):
            #	print(self.memory.memory)
            #print(self.off_policy_memory.memory)
            self.clear_memory()
            self.off_policy_memory.clean_memory()
            #start = t.time()
            #print("waiting memory collectd time passed", t.time() - start)
            # Release the workers for the next collection round.
            self.traffic_light.switch()
            self.counter.reset()
Пример #2
0
def train_eval(args, train_data, dev_data, positions):
    """Train a bbox-based metric-learning model with triplet loss and return
    the best dev EER observed across epochs.

    Args:
        args: argparse namespace (augmentation, optimizer, schedule, eval
            and checkpointing options).
        train_data: training split passed to the dataset builders.
        dev_data: dev split used for EER evaluation.
        positions: bounding-box position data for dataset construction.

    Returns:
        float: lowest dev-set EER seen at any evaluation checkpoint.

    Raises:
        ValueError: if ``args.optimizer`` is not SGD, Adam, or RAdam.
    """
    _bbox_collate_fn = partial(bbox_collate_fn, max_bb_num=args.bb_num)

    # Create dataset & dataloader. Train transforms add rotation/erasing
    # augmentation; dev transforms are deterministic.
    trans = [
        PadResize(224),
        transforms.RandomRotation(degrees=args.aug_rot),
        transforms.ToTensor(),
        # ImageNet normalization stats.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=args.aug_erase_p,
                                 scale=(args.aug_erase_min,
                                        args.aug_erase_max))
    ]
    trans = transforms.Compose(trans)
    dev_trans = [
        PadResize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    dev_trans = transforms.Compose(dev_trans)

    train_dataset, train_char_idx = \
        create_datasetBB(args.root, train_data, positions,
            post_crop_transform=trans,
            collate_fn_bbox=_bbox_collate_fn,
            bbox_scale=args.bb_scale)

    train_sampler = MetricBatchSampler(train_dataset,
                                       train_char_idx,
                                       n_max_per_char=args.n_max_per_char,
                                       n_batch_size=args.n_batch_size,
                                       n_random=args.n_random)
    train_dataloader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  batch_size=1,
                                  num_workers=5)

    eval_train_dataloaders = \
        prepare_evaluation_dataloadersBB(args, args.eval_split*3, train_data, positions,
            post_crop_transform=dev_trans,
            collate_fn_bbox=_bbox_collate_fn,
            bbox_scale=args.bb_scale
        )
    eval_dev_dataloaders = \
        prepare_evaluation_dataloadersBB(args, args.eval_split, dev_data, positions,
            post_crop_transform=dev_trans,
            collate_fn_bbox=_bbox_collate_fn,
            bbox_scale=args.bb_scale
        )

    # Construct model & optimizer
    device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)

    trunk, model = create_models(args.emb_dim, args.dropout)
    trunk.to(device)
    model.to(device)

    params = list(trunk.parameters()) + list(model.parameters())
    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.decay)
    elif args.optimizer == "Adam":
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.decay)
    elif args.optimizer == "RAdam":
        optimizer = RAdam(params, lr=args.lr, weight_decay=args.decay)
    else:
        # Fix: previously an unknown optimizer left `optimizer` unbound and
        # crashed later with an unrelated NameError.
        raise ValueError("Unknown optimizer: {}".format(args.optimizer))

    def lr_func(step):
        # Linear warmup, then stepwise exponential decay.
        if step < args.warmup:
            return (step + 1) / args.warmup
        else:
            steps_decay = step // args.decay_freq
            return 1 / args.decay_factor**steps_decay

    # Fix: LambdaLR must only be constructed when it is actually used —
    # its __init__ immediately applies lr_func(0) (= 1/warmup) to the
    # optimizer's lr, so building it and then discarding it for RAdam left
    # RAdam training with a permanently scaled-down learning rate.
    if args.optimizer == "RAdam":
        scheduler = None
    else:
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_func)

    best_dev_eer = 1.0
    for i_epoch in range(args.epoch):
        logger.info(f"EPOCH: {i_epoch}")

        bar = tqdm(total=len(train_dataloader), smoothing=0.0)
        for (img, mask), labels in train_dataloader:
            optimizer.zero_grad()

            img, mask = img.to(device), mask.to(device)
            embedding = model(trunk([img, mask]))

            a_idx, p_idx, n_idx = get_all_triplets_indices(labels)
            if a_idx.size(0) == 0:
                logger.info("Zero triplet. Skip.")
                continue
            anchors, positives, negatives = embedding[a_idx], embedding[
                p_idx], embedding[n_idx]
            # sim_func returns a similarity; negate to get a distance.
            a_p_dist = -sim_func(anchors, positives)
            a_n_dist = -sim_func(anchors, negatives)

            # Triplet margin loss averaged over active (non-zero) triplets.
            dist = a_p_dist - a_n_dist
            loss_modified = dist + args.margin
            relued = torch.nn.functional.relu(loss_modified)
            num_non_zero_triplets = (relued > 0).nonzero().size(0)
            if num_non_zero_triplets > 0:
                loss = torch.sum(relued) / num_non_zero_triplets
                loss.backward()
                optimizer.step()

            if scheduler is not None:
                scheduler.step()
            bar.update()
        bar.close()

        if i_epoch % args.eval_freq == 0:
            train_eer, train_eer_std = evaluate(args,
                                                trunk,
                                                model,
                                                eval_train_dataloaders,
                                                sim_func=sim_func_pair)
            dev_eer, dev_eer_std = evaluate(args,
                                            trunk,
                                            model,
                                            eval_dev_dataloaders,
                                            sim_func=sim_func_pair)
            logger.info("Train EER (mean, std):\t{}\t{}".format(
                train_eer, train_eer_std))
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                dev_eer, dev_eer_std))
            if dev_eer < best_dev_eer:
                logger.info("New best model!")
                best_dev_eer = dev_eer

                if args.save_model:
                    save_models = {
                        "trunk": trunk.state_dict(),
                        "embedder": model.state_dict(),
                        "args": [args.emb_dim, args.dropout]
                    }
                    torch.save(save_models, f"model/{args.suffix}.mdl")

    return best_dev_eer
Пример #3
0
    # Fragment of a CNN training epoch: accumulates cross-entropy loss and
    # running accuracy over train_loader. `epoch`, `train_loader`, `cnn`,
    # `criterion`, `cnn_optimizer`, `xentropy_loss_avg`, and `args` are
    # defined earlier in the enclosing (not shown) function.
    correct = 0.
    total = 0.

    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch))

        images = images.cuda()
        labels = labels.cuda()

        # Standard step: clear grads, forward, loss, backward, update.
        cnn.zero_grad()
        pred = cnn(images)

        xentropy_loss = criterion(pred, labels)
        xentropy_loss.backward()
        cnn_optimizer.step()

        xentropy_loss_avg += xentropy_loss.item()

        # Calculate running average of accuracy
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total
        xentropy = xentropy_loss_avg / (i + 1)

        progress_bar.set_postfix(xentropy='%.4f' % (xentropy),
                                 acc='%.4f' % accuracy)

    # NOTE(review): presumably restores fast weights from the Lookahead
    # optimizer's cache before evaluation — confirm against the Lookahead
    # implementation used.
    if args.lookahead:
        cnn_optimizer._backup_and_load_cache()
Пример #4
0
class Trainer:
    """GPT pre-training / fine-tuning driver.

    Exactly one of ``args.pretrain`` / ``args.finetune`` must be set:
    pre-training wraps the GPT in a language-model head, fine-tuning in a
    classification head. Supports DistributedDataParallel; TensorBoard
    logging happens only on rank -1/0.
    """

    def __init__(self, args, train_loader, test_loader, tokenizer):
        self.args = args
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size
        self.pad_id = tokenizer.pad_token_id
        self.eos_id = tokenizer.eos_token_id
        # NOTE(review): passes args.local_rank as the device index even when
        # the type resolves to 'cpu' — confirm local_rank is valid there.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and not args.no_cuda else
            'cpu', args.local_rank)
        # Only rank -1 (non-distributed) or rank 0 writes TensorBoard logs.
        self.writer = SummaryWriter() if args.local_rank in [-1, 0] else None
        self.n_gpus = torch.distributed.get_world_size(
        ) if args.distributed else torch.cuda.device_count()
        assert args.pretrain != args.finetune  # Do not set both finetune and pretrain arguments to the same (True, False)

        if args.pretrained_model:
            # Resume from a full serialized GPT module.
            self.gpt = torch.load(args.pretrained_model)
        else:
            self.gpt = GPT(vocab_size=self.vocab_size,
                           seq_len=args.max_seq_len,
                           d_model=args.hidden,
                           n_layers=args.n_layers,
                           n_heads=args.n_attn_heads,
                           d_ff=args.ffn_hidden,
                           embd_pdrop=args.embd_dropout,
                           attn_pdrop=args.attn_dropout,
                           resid_pdrop=args.resid_dropout,
                           pad_id=self.pad_id)

        if args.pretrain:
            # Language-modeling head for next-token prediction.
            self.model = GPTLMHead(self.gpt)
            self.model.to(self.device)
        if args.finetune:
            with open(args.cached_label_dict, 'r') as file:
                label_dict = json.load(file)
            # Classification head; uses the EOS position as the cls token.
            self.model = GPTClsHead(self.gpt,
                                    n_class=len(label_dict),
                                    cls_token_id=self.eos_id)
            self.model.to(self.device)

        if args.distributed:
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[args.local_rank],
                                                 output_device=args.local_rank)

        self.optimizer = RAdam(self.model.parameters(), args.lr)
        # LM loss ignores padding tokens; classification loss does not.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id).to(
            self.device)
        self.cls_criterion = nn.CrossEntropyLoss().to(self.device)

    @timeit
    def train(self, epoch):
        """Dispatch one epoch to pretrain() or finetune() based on args."""
        if self.args.pretrain:
            self.pretrain(epoch)
        if self.args.finetune:
            self.finetune(epoch)

    def pretrain(self, epoch):
        """One epoch of next-token language-model pre-training."""
        losses = 0
        n_batches, n_samples = len(self.train_loader), len(
            self.train_loader.dataset)

        self.model.train()
        for i, batch in enumerate(self.train_loader):
            inputs = batch[0].to(self.device)
            # Targets are the inputs shifted left by one position.
            targets = inputs[:, 1:].contiguous()
            # |inputs| : (batch_size, seq_len), |targets| : (batch_size, seq_len-1)

            lm_logits = self.model(inputs)
            # Drop the final-position logits so logits and targets align.
            lm_logits = lm_logits[:, :-1].contiguous()
            # |lm_logits| : (batch_size, seq_len-1, vocab_size)

            loss = self.criterion(lm_logits.view(-1, self.vocab_size),
                                  targets.view(-1))
            losses += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.args.local_rank in [-1, 0]:
                self.writer.add_scalar('Loss/pre-train', loss.item(),
                                       ((epoch - 1) * n_batches) + i)
                # Print progress roughly 5 times per epoch.
                if i % (n_batches // 5) == 0 and i != 0:
                    print('Iteration {} ({}/{})\tLoss: {:.4f}'.format(
                        i, i, n_batches, losses / i))

        print('Train Epoch {} [rank: {}]\t>\tLoss: {:.4f}'.format(
            epoch, self.args.local_rank, losses / n_batches))

    def finetune(self, epoch):
        """One epoch of classification fine-tuning with an auxiliary LM loss
        weighted by args.auxiliary_ratio."""
        losses, accs = 0, 0
        n_batches, n_samples = len(self.train_loader), len(
            self.train_loader.dataset)  # n_batches = batch size per GPU

        self.model.train()
        for i, batch in enumerate(self.train_loader):
            inputs, labels = map(lambda x: x.to(self.device), batch)
            # |inputs| : (batch_size, seq_len), |labels| : (batch_size)

            lm_logits, cls_logits = self.model(inputs)
            lm_logits = lm_logits[:, :-1].contiguous()
            # |lm_logits| : (batch_size, seq_len-1, vocab_size), |cls_logits| : (batch_size, n_class)

            lm_loss = self.criterion(lm_logits.view(-1, self.vocab_size),
                                     inputs[:, 1:].contiguous().view(-1))
            cls_loss = self.cls_criterion(cls_logits, labels)
            # Combined objective: classification + weighted LM auxiliary.
            loss = cls_loss + (self.args.auxiliary_ratio * lm_loss)

            losses += loss.item()
            acc = (cls_logits.argmax(dim=-1) == labels).to(
                dtype=cls_logits.dtype).mean()
            accs += acc

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.args.local_rank in [-1, 0]:
                self.writer.add_scalar('Loss/fine-tune', loss.item(),
                                       ((epoch - 1) * n_batches) + i)
                self.writer.add_scalar('Accuracy/fine-tune', acc,
                                       ((epoch - 1) * n_batches) + i)
                if i % (n_batches // 5) == 0 and i != 0:
                    print('Iteration {} ({}/{})\tLoss: {:.4f} Acc: {:.1f}%'.
                          format(i, i, n_batches, losses / i, accs / i * 100.))

        print(
            'Train Epoch {} [rank: {}]\t>\tLoss: {:.4f} / Acc: {:.1f}%'.format(
                epoch, self.args.local_rank, losses / n_batches,
                accs / n_batches * 100.))

    def evaluate(self, epoch):
        """Evaluate on the test set in either pretrain or finetune mode.

        No gradients are computed; results are logged to TensorBoard on
        rank -1/0 and printed at the end.
        """
        losses, accs = 0, 0
        n_batches, n_samples = len(self.test_loader), len(
            self.test_loader.dataset)

        self.model.eval()
        with torch.no_grad():
            for i, batch in enumerate(self.test_loader):
                if self.args.pretrain:
                    # NOTE(review): here the whole batch is the input tensor,
                    # while pretrain() uses batch[0] — presumably the two
                    # loaders yield different shapes; confirm.
                    inputs = batch.to(self.device)
                    targets = inputs[:, 1:].contiguous()

                    lm_logits = self.model(inputs)
                    lm_logits = lm_logits[:, :-1].contiguous()

                    loss = self.criterion(lm_logits.view(-1, self.vocab_size),
                                          targets.view(-1))
                    losses += loss.item()

                    if self.args.local_rank in [-1, 0]:
                        self.writer.add_scalar('Loss/pre-train(eval)',
                                               loss.item(),
                                               ((epoch - 1) * n_batches) + i)

                elif self.args.finetune:
                    inputs, labels = map(lambda x: x.to(self.device), batch)

                    lm_logits, cls_logits = self.model(inputs)
                    lm_logits = lm_logits[:, :-1].contiguous()

                    lm_loss = self.criterion(
                        lm_logits.view(-1, self.vocab_size),
                        inputs[:, 1:].contiguous().view(-1))
                    cls_loss = self.cls_criterion(cls_logits, labels)
                    loss = cls_loss + (self.args.auxiliary_ratio * lm_loss)

                    losses += loss.item()
                    acc = (cls_logits.argmax(dim=-1) == labels).to(
                        dtype=cls_logits.dtype).mean()
                    accs += acc

                    if self.args.local_rank in [-1, 0]:
                        self.writer.add_scalar('Loss/fine-tune(eval)',
                                               loss.item(),
                                               ((epoch - 1) * n_batches) + i)
                        self.writer.add_scalar('Accuracy/fine-tune(eval)', acc,
                                               ((epoch - 1) * n_batches) + i)

        print(
            'Eval Epoch {} [rank: {}]\t>\tLoss: {:.4f} / Acc: {:.1f}%'.format(
                epoch, self.args.local_rank, losses / n_batches,
                accs / n_batches * 100.))

    def save(self, epoch, model_prefix='model', root='.model'):
        """Save the underlying GPT module (not the head wrapper) to
        `<root>/<model_prefix>.ep<epoch>`; only rank 0 saves when distributed."""
        path = Path(root) / (model_prefix + '.ep%d' % epoch)
        if not path.parent.exists():
            path.parent.mkdir()

        if self.args.distributed:
            if self.args.local_rank == 0:
                torch.save(self.gpt, path)
        else:
            torch.save(self.gpt, path)
Пример #5
0
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        type=str, help="directory or list of wav files")
    parser.add_argument("--waveforms_eval",
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--feats_eval", required=True,
                        type=str, help="directory or list of evaluation feat files")
    parser.add_argument("--stats", required=True,
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    # network structure setting
    parser.add_argument("--upsampling_factor", default=120,
                        type=int, help="number of dimension of aux feats")
    parser.add_argument("--hidden_units_wave", default=384,
                        type=int, help="depth of dilation")
    parser.add_argument("--hidden_units_wave_2", default=16,
                        type=int, help="depth of dilation")
    parser.add_argument("--kernel_size_wave", default=7,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--dilation_size_wave", default=1,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--lpc", default=12,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--mcep_dim", default=50,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--right_size", default=0,
                        type=int, help="kernel size of dilated causal convolution")
    # network training setting
    parser.add_argument("--lr", default=1e-4,
                        type=float, help="learning rate")
    parser.add_argument("--batch_size", default=15,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count", default=4000,
                        type=int, help="number of training epochs")
    parser.add_argument("--do_prob", default=0,
                        type=float, help="dropout probability")
    parser.add_argument("--batch_size_utt", default=5,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--batch_size_utt_eval", default=5,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--n_workers", default=2,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--n_quantize", default=256,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--causal_conv_wave", default=False,
                        type=strtobool, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--n_stage", default=4,
                        type=int, help="number of sparsification stages")
    parser.add_argument("--t_start", default=20000,
                        type=int, help="iter idx to start sparsify")
    parser.add_argument("--t_end", default=4500000,
                        type=int, help="iter idx to finish densitiy sparsify")
    parser.add_argument("--interval", default=100,
                        type=int, help="interval in finishing densitiy sparsify")
    parser.add_argument("--densities", default="0.05-0.05-0.2",
                        type=str, help="final densitiy of reset, update, new hidden gate matrices")
    # other setting
    parser.add_argument("--pad_len", default=3000,
                        type=int, help="seed number")
    parser.add_argument("--save_interval_iter", default=5000,
                        type=int, help="interval steps to logr")
    parser.add_argument("--save_interval_epoch", default=10,
                        type=int, help="interval steps to logr")
    parser.add_argument("--log_interval_steps", default=50,
                        type=int, help="interval steps to logr")
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--pretrained", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--string_path", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--GPU_device", default=None,
                        type=int, help="selection of GPU device")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"]     = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"]  = str(args.GPU_device)

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(level=logging.WARN,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if str(device) == "cpu":
        raise ValueError('ERROR: Training by CPU is not acceptable.')

    torch.backends.cudnn.benchmark = True #faster

    #if args.pretrained is None:
    if 'mel' in args.string_path:
        mean_stats = torch.FloatTensor(read_hdf5(args.stats, "/mean_melsp"))
        scale_stats = torch.FloatTensor(read_hdf5(args.stats, "/scale_melsp"))
        args.excit_dim = 0
        #mean_stats = torch.FloatTensor(np.r_[read_hdf5(args.stats, "/mean_feat_mceplf0cap")[:2], read_hdf5(args.stats, "/mean_melsp")])
        #scale_stats = torch.FloatTensor(np.r_[read_hdf5(args.stats, "/scale_feat_mceplf0cap")[:2], read_hdf5(args.stats, "/scale_melsp")])
        #args.excit_dim = 2
        #mean_stats = torch.FloatTensor(np.r_[read_hdf5(args.stats, "/mean_feat_mceplf0cap")[:6], read_hdf5(args.stats, "/mean_melsp")])
        #scale_stats = torch.FloatTensor(np.r_[read_hdf5(args.stats, "/scale_feat_mceplf0cap")[:6], read_hdf5(args.stats, "/scale_melsp")])
        #args.excit_dim = 6
    else:
        mean_stats = torch.FloatTensor(read_hdf5(args.stats, "/mean_"+args.string_path.replace("/","")))
        scale_stats = torch.FloatTensor(read_hdf5(args.stats, "/scale_"+args.string_path.replace("/","")))
        if mean_stats.shape[0] > args.mcep_dim+2:
            if 'feat_org_lf0' in args.string_path:
                args.cap_dim = mean_stats.shape[0]-(args.mcep_dim+2)
                args.excit_dim = 2+args.cap_dim
            else:
                args.cap_dim = mean_stats.shape[0]-(args.mcep_dim+3)
                args.excit_dim = 2+1+args.cap_dim
            #args.cap_dim = mean_stats.shape[0]-(args.mcep_dim+2)
            #args.excit_dim = 2+args.cap_dim
        else:
            args.cap_dim = None
            args.excit_dim = 2
    #else:
    #    if 'mel' in args.string_path:
    #        args.excit_dim = 0
    #    else:
    #        args.cap_dim = 3
    #        if 'legacy' not in args.string_path:
    #            args.excit_dim = 6
    #        else:
    #            args.excit_dim = 5

    # save args as conf
    # 14/15-8 or 14/15/16-6/7/8 [5ms]
    # 7-8 or 8-6/7/8 [10ms]
    #args.batch_size = 7
    #args.batch_size_utt = 8
    #args.batch_size = 8
    #args.batch_size_utt = 6
    #args.codeap_dim = 3
    torch.save(args, args.expdir + "/model.conf")
    #args.batch_size = 10
    #batch_sizes = [None]*3
    #batch_sizes[0] = int(args.batch_size*0.5)
    #batch_sizes[1] = int(args.batch_size)
    #batch_sizes[2] = int(args.batch_size*1.5)
    #logging.info(batch_sizes)

    # define network
    model_waveform = GRU_WAVE_DECODER_DUALGRU_COMPACT(
        feat_dim=args.mcep_dim+args.excit_dim,
        upsampling_factor=args.upsampling_factor,
        hidden_units=args.hidden_units_wave,
        hidden_units_2=args.hidden_units_wave_2,
        kernel_size=args.kernel_size_wave,
        dilation_size=args.dilation_size_wave,
        n_quantize=args.n_quantize,
        causal_conv=args.causal_conv_wave,
        lpc=args.lpc,
        right_size=args.right_size,
        do_prob=args.do_prob)
    logging.info(model_waveform)
    criterion_ce = torch.nn.CrossEntropyLoss(reduction='none')
    criterion_l1 = torch.nn.L1Loss(reduction='none')

    # send to gpu
    if torch.cuda.is_available():
        model_waveform.cuda()
        criterion_ce.cuda()
        criterion_l1.cuda()
        if args.pretrained is None:
            mean_stats = mean_stats.cuda()
            scale_stats = scale_stats.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model_waveform.train()

    if args.pretrained is None:
        model_waveform.scale_in.weight = torch.nn.Parameter(torch.unsqueeze(torch.diag(1.0/scale_stats.data),2))
        model_waveform.scale_in.bias = torch.nn.Parameter(-(mean_stats.data/scale_stats.data))

    for param in model_waveform.parameters():
        param.requires_grad = True
    for param in model_waveform.scale_in.parameters():
        param.requires_grad = False
    if args.lpc > 0:
        for param in model_waveform.logits.parameters():
            param.requires_grad = False

    parameters = filter(lambda p: p.requires_grad, model_waveform.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters (waveform): %.3f million' % parameters)

    module_list = list(model_waveform.conv.parameters())
    module_list += list(model_waveform.conv_s_c.parameters()) + list(model_waveform.embed_wav.parameters())
    module_list += list(model_waveform.gru.parameters()) + list(model_waveform.gru_2.parameters())
    module_list += list(model_waveform.out.parameters())

    optimizer = RAdam(module_list, lr=args.lr)
    #optimizer = torch.optim.Adam(module_list, lr=args.lr)
    #if args.pretrained is None:
    #    optimizer = RAdam(module_list, lr=args.lr)
    #else:
    #    #optimizer = RAdam(module_list, lr=args.lr)
    #    optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # resume
    if args.pretrained is not None and args.resume is None:
        checkpoint = torch.load(args.pretrained)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
    #    optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0
    elif args.resume is not None:
        checkpoint = torch.load(args.resume)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    else:
        epoch_idx = 0

    def zero_wav_pad(x): return padding(x, args.pad_len*args.upsampling_factor, value=0.0)  # noqa: E704
    def zero_feat_pad(x): return padding(x, args.pad_len, value=0.0)  # noqa: E704
    pad_wav_transform = transforms.Compose([zero_wav_pad])
    pad_feat_transform = transforms.Compose([zero_feat_pad])

    wav_transform = transforms.Compose([lambda x: encode_mu_law(x, args.n_quantize)])

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats):
        feat_list = [args.feats + "/" + filename for filename in filenames]
    elif os.path.isfile(args.feats):
        feat_list = read_txt(args.feats)
    else:
        logging.error("--feats should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(feat_list))
    dataset = FeatureDatasetNeuVoco(wav_list, feat_list, pad_wav_transform, pad_feat_transform, args.upsampling_factor, 
                    args.string_path, wav_transform=wav_transform)
                    #args.string_path, wav_transform=wav_transform, with_excit=True)
                    #args.string_path, wav_transform=wav_transform, with_excit=False)
                    #args.string_path, wav_transform=wav_transform, with_excit=True, codeap_dim=args.codeap_dim)
    dataloader = DataLoader(dataset, batch_size=args.batch_size_utt, shuffle=True, num_workers=args.n_workers)
    #generator = data_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=1)
    generator = data_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=None)
    #generator = data_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=None, batch_sizes=batch_sizes)

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [args.waveforms + "/" + filename for filename in filenames]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats_eval):
        feat_list_eval = [args.feats_eval + "/" + filename for filename in filenames]
    elif os.path.isfile(args.feats):
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--feats_eval should be directory or list.")
        sys.exit(1)
    assert len(wav_list_eval) == len(feat_list_eval)
    logging.info("number of evaluation data = %d." % len(feat_list_eval))
    dataset_eval = FeatureDatasetNeuVoco(wav_list_eval, feat_list_eval, pad_wav_transform, pad_feat_transform, args.upsampling_factor, 
                    args.string_path, wav_transform=wav_transform)
                    #args.string_path, wav_transform=wav_transform, with_excit=False)
                    #args.string_path, wav_transform=wav_transform, with_excit=True)
                    #args.string_path, wav_transform=wav_transform, with_excit=True, codeap_dim=args.codeap_dim)
    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size_utt_eval, shuffle=False, num_workers=args.n_workers)
    #generator_eval = data_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=1)
    generator_eval = data_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=None)

    writer = SummaryWriter(args.expdir)
    total_train_loss = defaultdict(list)
    total_eval_loss = defaultdict(list)

    #sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
    density_deltas_ = args.densities.split('-')
    density_deltas = [None]*len(density_deltas_)
    for i in range(len(density_deltas_)):
        density_deltas[i] = (1-float(density_deltas_[i]))/args.n_stage
    t_deltas = [None]*args.n_stage
    t_starts = [None]*args.n_stage
    t_ends = [None]*args.n_stage
    densities = [None]*args.n_stage
    t_delta = args.t_end - args.t_start + 1
    #t_deltas[0] = round((1/(args.n_stage-1))*0.6*t_delta)
    if args.n_stage > 3:
        t_deltas[0] = round((1/2)*0.2*t_delta)
    else:
        t_deltas[0] = round(0.2*t_delta)
    t_starts[0] = args.t_start
    t_ends[0] = args.t_start + t_deltas[0] - 1
    densities[0] = [None]*len(density_deltas)
    for j in range(len(density_deltas)):
        densities[0][j] = 1-density_deltas[j]
    for i in range(1,args.n_stage):
        if i < args.n_stage-1:
            #t_deltas[i] = round((1/(args.n_stage-1))*0.6*t_delta)
            if args.n_stage > 3:
                if i < 2:
                    t_deltas[i] = round((1/2)*0.2*t_delta)
                else:
                    if args.n_stage > 4:
                        t_deltas[i] = round((1/2)*0.3*t_delta)
                    else:
                        t_deltas[i] = round(0.3*t_delta)
            else:
                t_deltas[i] = round(0.3*t_delta)
        else:
            #t_deltas[i] = round(0.4*t_delta)
            t_deltas[i] = round(0.5*t_delta)
        t_starts[i] = t_ends[i-1] + 1
        t_ends[i] = t_starts[i] + t_deltas[i] - 1
        densities[i] = [None]*len(density_deltas)
        if i < args.n_stage-1:
            for j in range(len(density_deltas)):
                densities[i][j] = densities[i-1][j]-density_deltas[j]
        else:
            for j in range(len(density_deltas)):
                densities[i][j] = float(density_deltas_[j])
    logging.info(t_delta)
    logging.info(t_deltas)
    logging.info(t_starts)
    logging.info(t_ends)
    logging.info(args.interval)
    logging.info(densities)
    idx_stage = 0

    # train
    total = 0
    iter_count = 0
    loss_ce = []
    loss_err = []
    min_eval_loss_ce = 99999999.99
    min_eval_loss_ce_std = 99999999.99
    min_eval_loss_err = 99999999.99
    min_eval_loss_err_std = 99999999.99
    iter_idx = 0
    min_idx = -1
    #min_eval_loss_ce = 2.007181
    #min_eval_loss_ce_std = 0.801412
    #iter_idx = 70350
    #min_idx = 6 #resume7
    while idx_stage < args.n_stage-1 and iter_idx + 1 >= t_starts[idx_stage+1]:
        idx_stage += 1
        logging.info(idx_stage)
    change_min_flag = False
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx+1))
    logging.info("Training data")
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
            del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        if c_idx < 0: # summarize epoch
            # save current epoch model
            numpy_random_state = np.random.get_state()
            torch_random_state = torch.get_rng_state()
            # report current epoch
            logging.info("(EPOCH:%d) average optimization loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; "\
                "(%.3f min., %.3f sec / batch)" % (epoch_idx + 1, np.mean(loss_ce), np.std(loss_ce), \
                    np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count))
            logging.info("estimated time until max. epoch = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\
            "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            model_waveform.eval()
            for param in model_waveform.parameters():
                param.requires_grad = False
            logging.info("Evaluation data")
            while True:
                with torch.no_grad():
                    start = time.time()
                    batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                        del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator_eval)
                    if c_idx < 0:
                        break

                    x_es = x_ss+x_bs
                    f_es = f_ss+f_bs
                    logging.info(f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}')
                    if x_ss > 0:
                        if x_es <= max_slen:
                            batch_x_prev = batch_x[:,x_ss-1:x_es-1]
                            if args.lpc > 0:
                                if x_ss-args.lpc >= 0:
                                    batch_x_lpc = batch_x[:,x_ss-args.lpc:x_es-1]
                                else:
                                    batch_x_lpc = F.pad(batch_x[:,:x_es-1], (-(x_ss-args.lpc), 0), "constant", args.n_quantize // 2)
                            batch_feat = batch_feat[:,f_ss:f_es]
                            batch_x = batch_x[:,x_ss:x_es]
                        else:
                            batch_x_prev = batch_x[:,x_ss-1:-1]
                            if args.lpc > 0:
                                if x_ss-args.lpc >= 0:
                                    batch_x_lpc = batch_x[:,x_ss-args.lpc:-1]
                                else:
                                    batch_x_lpc = F.pad(batch_x[:,:-1], (-(x_ss-args.lpc), 0), "constant", args.n_quantize // 2)
                            batch_feat = batch_feat[:,f_ss:]
                            batch_x = batch_x[:,x_ss:]
                    else:
                        batch_x_prev = F.pad(batch_x[:,:x_es-1], (1, 0), "constant", args.n_quantize // 2)
                        if args.lpc > 0:
                            batch_x_lpc = F.pad(batch_x[:,:x_es-1], (args.lpc, 0), "constant", args.n_quantize // 2)
                        batch_feat = batch_feat[:,:f_es]
                        batch_x = batch_x[:,:x_es]
                    #assert((batch_x_prev[:,1:] == batch_x[:,:-1]).all())

                    if f_ss > 0:
                        if len(del_index_utt) > 0:
                            h_x = torch.FloatTensor(np.delete(h_x.cpu().data.numpy(), del_index_utt, axis=1)).to(device)
                            h_x_2 = torch.FloatTensor(np.delete(h_x_2.cpu().data.numpy(), del_index_utt, axis=1)).to(device)
                        if args.lpc > 0:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, h=h_x, h_2=h_x_2, x_lpc=batch_x_lpc)
                        else:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, h=h_x, h_2=h_x_2)
                    else:
                        if args.lpc > 0:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, x_lpc=batch_x_lpc)
                        else:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev)

                    # samples check
                    i = np.random.randint(0, batch_x_output.shape[0])
                    logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),os.path.basename(featfile[i]))))
                    #check_samples = batch_x[i,5:10].long()
                    #logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples))
                    #logging.info(check_samples)

                    # handle short ending
                    if len(idx_select) > 0:
                        logging.info('len_idx_select: '+str(len(idx_select)))
                        batch_loss_ce_select = 0
                        batch_loss_err_select = 0
                        for j in range(len(idx_select)):
                            k = idx_select[j]
                            slens_utt = slens_acc[k]
                            logging.info('%s %d' % (featfile[k], slens_utt))
                            batch_x_output_ = batch_x_output[k,:slens_utt]
                            batch_x_ = batch_x[k,:slens_utt]
                            batch_loss_ce_select += torch.mean(criterion_ce(batch_x_output_, batch_x_))
                            batch_loss_err_select += torch.mean(torch.sum(100*criterion_l1(F.softmax(batch_x_output_, dim=-1), F.one_hot(batch_x_, num_classes=args.n_quantize).float()), -1))
                        batch_loss += batch_loss_ce_select
                        batch_loss_ce_select /= len(idx_select)
                        batch_loss_err_select /= len(idx_select)
                        total_eval_loss["eval/loss_ce"].append(batch_loss_ce_select.item())
                        total_eval_loss["eval/loss_err"].append(batch_loss_err_select.item())
                        loss_ce.append(batch_loss_ce_select.item())
                        loss_err.append(batch_loss_err_select.item())
                        if len(idx_select_full) > 0:
                            logging.info('len_idx_select_full: '+str(len(idx_select_full)))
                            batch_x = torch.index_select(batch_x,0,idx_select_full)
                            batch_x_output = torch.index_select(batch_x_output,0,idx_select_full)
                        else:
                            logging.info("batch loss select %.3f %.3f (%.3f sec)" % (batch_loss_ce_select.item(), \
                                batch_loss_err_select.item(), time.time() - start))
                            iter_count += 1
                            total += time.time() - start
                            continue

                    # loss
                    batch_loss_ce_ = torch.mean(criterion_ce(batch_x_output.reshape(-1, args.n_quantize), batch_x.reshape(-1)).reshape(batch_x_output.shape[0], -1), -1)
                    batch_loss_err_ = torch.mean(torch.sum(100*criterion_l1(F.softmax(batch_x_output, dim=-1), F.one_hot(batch_x, num_classes=args.n_quantize).float()), -1), -1)

                    batch_loss_ce = batch_loss_ce_.mean()
                    batch_loss_err = batch_loss_err_.mean()
                    total_eval_loss["eval/loss_ce"].append(batch_loss_ce.item())
                    total_eval_loss["eval/loss_err"].append(batch_loss_err.item())
                    loss_ce.append(batch_loss_ce.item())
                    loss_err.append(batch_loss_err.item())

                    logging.info("batch eval loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, x_ss, x_bs, \
                        f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            logging.info('sme')
            for key in total_eval_loss.keys():
                total_eval_loss[key] = np.mean(total_eval_loss[key])
                logging.info(f"(Steps: {iter_idx}) {key} = {total_eval_loss[key]:.4f}.")
            write_to_tensorboard(writer, iter_idx, total_eval_loss)
            total_eval_loss = defaultdict(list)
            eval_loss_ce = np.mean(loss_ce)
            eval_loss_ce_std = np.std(loss_ce)
            eval_loss_err = np.mean(loss_err)
            eval_loss_err_std = np.std(loss_err)
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; (%.3f min., "\
                "%.3f sec / batch)" % (epoch_idx + 1, eval_loss_ce, eval_loss_ce_std, \
                    eval_loss_err, eval_loss_err_std, total / 60.0, total / iter_count))
            if (eval_loss_ce+eval_loss_ce_std) <= (min_eval_loss_ce+min_eval_loss_ce_std) \
                or (eval_loss_ce <= min_eval_loss_ce):
                min_eval_loss_ce = eval_loss_ce
                min_eval_loss_ce_std = eval_loss_ce_std
                min_eval_loss_err = eval_loss_err
                min_eval_loss_err_std = eval_loss_err_std
                min_idx = epoch_idx
                change_min_flag = True
            if change_min_flag:
                logging.info("min_eval_loss = %.6f (+- %.6f) %.6f (+- %.6f) %% min_idx=%d" % (min_eval_loss_ce, \
                    min_eval_loss_ce_std, min_eval_loss_err, min_eval_loss_err_std, min_idx+1))
            #if ((epoch_idx + 1) % args.save_interval_epoch == 0) or (epoch_min_flag):
            #    logging.info('save epoch:%d' % (epoch_idx+1))
            #    save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state, torch_random_state, epoch_idx + 1)
            logging.info('save epoch:%d' % (epoch_idx+1))
            save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state, torch_random_state, epoch_idx + 1)
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            epoch_idx += 1
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model_waveform.train()
            for param in model_waveform.parameters():
                param.requires_grad = True
            for param in model_waveform.scale_in.parameters():
                param.requires_grad = False
            if args.lpc > 0:
                for param in model_waveform.logits.parameters():
                    param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx+1))
                logging.info("Training data")
                batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                    del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1))

            x_es = x_ss+x_bs
            f_es = f_ss+f_bs
            logging.info(f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}')
            if x_ss > 0:
                if x_es <= max_slen:
                    batch_x_prev = batch_x[:,x_ss-1:x_es-1]
                    if args.lpc > 0:
                        if x_ss-args.lpc >= 0:
                            batch_x_lpc = batch_x[:,x_ss-args.lpc:x_es-1]
                        else:
                            batch_x_lpc = F.pad(batch_x[:,:x_es-1], (-(x_ss-args.lpc), 0), "constant", args.n_quantize // 2)
                    batch_feat = batch_feat[:,f_ss:f_es]
                    batch_x = batch_x[:,x_ss:x_es]
                else:
                    batch_x_prev = batch_x[:,x_ss-1:-1]
                    if args.lpc > 0:
                        if x_ss-args.lpc >= 0:
                            batch_x_lpc = batch_x[:,x_ss-args.lpc:-1]
                        else:
                            batch_x_lpc = F.pad(batch_x[:,:-1], (-(x_ss-args.lpc), 0), "constant", args.n_quantize // 2)
                    batch_feat = batch_feat[:,f_ss:]
                    batch_x = batch_x[:,x_ss:]
            else:
                batch_x_prev = F.pad(batch_x[:,:x_es-1], (1, 0), "constant", args.n_quantize // 2)
                if args.lpc > 0:
                    batch_x_lpc = F.pad(batch_x[:,:x_es-1], (args.lpc, 0), "constant", args.n_quantize // 2)
                batch_feat = batch_feat[:,:f_es]
                batch_x = batch_x[:,:x_es]
            #assert((batch_x_prev[:,1:] == batch_x[:,:-1]).all())

            if f_ss > 0:
                if len(del_index_utt) > 0:
                    h_x = torch.FloatTensor(np.delete(h_x.cpu().data.numpy(), del_index_utt, axis=1)).to(device)
                    h_x_2 = torch.FloatTensor(np.delete(h_x_2.cpu().data.numpy(), del_index_utt, axis=1)).to(device)
                if args.lpc > 0:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, h=h_x, h_2=h_x_2, x_lpc=batch_x_lpc, do=True)
                else:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, h=h_x, h_2=h_x_2, do=True)
            else:
                if args.lpc > 0:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, x_lpc=batch_x_lpc, do=True)
                else:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, do=True)

            # samples check
            #with torch.no_grad():
            i = np.random.randint(0, batch_x_output.shape[0])
            logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),os.path.basename(featfile[i]))))
            #    check_samples = batch_x[i,5:10].long()
            #    logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples))
            #    logging.info(check_samples)

            # handle short ending
            batch_loss = 0
            if len(idx_select) > 0:
                logging.info('len_idx_select: '+str(len(idx_select)))
                batch_loss_ce_select = 0
                batch_loss_err_select = 0
                for j in range(len(idx_select)):
                    k = idx_select[j]
                    slens_utt = slens_acc[k]
                    logging.info('%s %d' % (featfile[k], slens_utt))
                    batch_x_output_ = batch_x_output[k,:slens_utt]
                    batch_x_ = batch_x[k,:slens_utt]
                    batch_loss_ce_select += torch.mean(criterion_ce(batch_x_output_, batch_x_))
                    batch_loss_err_select += torch.mean(torch.sum(100*criterion_l1(F.softmax(batch_x_output_, dim=-1), F.one_hot(batch_x_, num_classes=args.n_quantize).float()), -1))
                batch_loss += batch_loss_ce_select
                batch_loss_ce_select /= len(idx_select)
                batch_loss_err_select /= len(idx_select)
                total_train_loss["train/loss_ce"].append(batch_loss_ce_select.item())
                total_train_loss["train/loss_err"].append(batch_loss_err_select.item())
                loss_ce.append(batch_loss_ce_select.item())
                loss_err.append(batch_loss_err_select.item())
                if len(idx_select_full) > 0:
                    logging.info('len_idx_select_full: '+str(len(idx_select_full)))
                    batch_x = torch.index_select(batch_x,0,idx_select_full)
                    batch_x_output = torch.index_select(batch_x_output,0,idx_select_full)
                #elif len(idx_select) > 1:
                else:
                    optimizer.zero_grad()
                    batch_loss.backward()
                    #for name, param in model_waveform.named_parameters():
                    #    if param.requires_grad:
                    #        logging.info(f"{name} {param.grad.norm()}")
                    flag = False
                    for name, param in model_waveform.named_parameters():
                        if param.requires_grad:
                            grad_norm = param.grad.norm()
                    #        logging.info(f"{name} {grad_norm}")
                            #if grad_norm >= 1e4 or torch.isnan(grad_norm) or torch.isinf(grad_norm):
                            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                                flag = True
                    if flag:
                        logging.info("explode grad")
                        optimizer.zero_grad()
                        continue
                    torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10)
                    #for name, param in model_waveform.named_parameters():
                    #    if param.requires_grad:
                    #        logging.info(f"{name} {param.grad.norm()}")
                    optimizer.step()

                    with torch.no_grad():
                        #test = model_waveform.gru.weight_hh_l0.data.clone()
                        #sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
                        #t_start, t_end, interval, densities
                        if idx_stage < args.n_stage-1 and iter_idx + 1 == t_starts[idx_stage+1]:
                            idx_stage += 1
                        if idx_stage > 0:
                            sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage], args.interval, densities[idx_stage], densities_p=densities[idx_stage-1])
                        else:
                            sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage], args.interval, densities[idx_stage])
                        #logging.info((test==model_waveform.gru.weight_hh_l0).all())

                    logging.info("batch loss select %.3f %.3f (%.3f sec)" % (batch_loss_ce_select.item(), \
                        batch_loss_err_select.item(), time.time() - start))
                    iter_idx += 1
                    if iter_idx % args.save_interval_iter == 0:
                        logging.info('save iter:%d' % (iter_idx))
                        save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx)
                    iter_count += 1
                    if iter_idx % args.log_interval_steps == 0:
                        logging.info('smt')
                        for key in total_train_loss.keys():
                            total_train_loss[key] = np.mean(total_train_loss[key])
                            logging.info(f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}.")
                        write_to_tensorboard(writer, iter_idx, total_train_loss)
                        total_train_loss = defaultdict(list)
                    total += time.time() - start
                    continue
                #else:
                #    continue

            # loss
            batch_loss_ce_ = torch.mean(criterion_ce(batch_x_output.reshape(-1, args.n_quantize), batch_x.reshape(-1)).reshape(batch_x_output.shape[0], -1), -1)
            batch_loss_err_ = torch.mean(torch.sum(100*criterion_l1(F.softmax(batch_x_output, dim=-1), F.one_hot(batch_x, num_classes=args.n_quantize).float()), -1), -1)

            batch_loss_ce = batch_loss_ce_.mean()
            batch_loss_err = batch_loss_err_.mean()
            total_train_loss["train/loss_ce"].append(batch_loss_ce.item())
            total_train_loss["train/loss_err"].append(batch_loss_err.item())
            loss_ce.append(batch_loss_ce.item())
            loss_err.append(batch_loss_err.item())

            batch_loss += batch_loss_ce_.sum()

            optimizer.zero_grad()
            batch_loss.backward()
            #for name, param in model_waveform.named_parameters():
            #    if param.requires_grad:
            #        logging.info(f"{name} {param.grad.norm()}")
            flag = False
            for name, param in model_waveform.named_parameters():
                if param.requires_grad:
                    grad_norm = param.grad.norm()
            #        logging.info(f"{name} {grad_norm}")
                    #if grad_norm >= 1e4 or torch.isnan(grad_norm) or torch.isinf(grad_norm):
                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                        flag = True
            if flag:
                logging.info("explode grad")
                optimizer.zero_grad()
                continue
            torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10)
            #for name, param in model_waveform.named_parameters():
            #    if param.requires_grad:
            #        logging.info(f"{name} {param.grad.norm()}")
            optimizer.step()

            with torch.no_grad():
                #test = model_waveform.gru.weight_hh_l0.data.clone()
                #sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
                #t_start, t_end, interval, densities
                if idx_stage < args.n_stage-1 and iter_idx + 1 == t_starts[idx_stage+1]:
                    idx_stage += 1
                if idx_stage > 0:
                    sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage], args.interval, densities[idx_stage], densities_p=densities[idx_stage-1])
                else:
                    sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage], args.interval, densities[idx_stage])
                #logging.info((test==model_waveform.gru.weight_hh_l0).all())

            logging.info("batch loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, x_ss, x_bs, \
                f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start))
            iter_idx += 1
            if iter_idx % args.save_interval_iter == 0:
                logging.info('save iter:%d' % (iter_idx))
                save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx)
            iter_count += 1
            if iter_idx % args.log_interval_steps == 0:
                logging.info('smt')
                for key in total_train_loss.keys():
                    total_train_loss[key] = np.mean(total_train_loss[key])
                    logging.info(f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}.")
                write_to_tensorboard(writer, iter_idx, total_train_loss)
                total_train_loss = defaultdict(list)
            total += time.time() - start


    # save final model
    model_waveform.cpu()
    torch.save({"model_waveform": model_waveform.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Example #6
0
def train(args):
    """Train a (rgb, audio) video model with periodic validation and checkpointing.

    Builds the model and data loaders, selects the optimizer and LR scheduler
    from ``args``, then runs the epoch loop; every ``args.iter_val`` iterations
    it saves a "latest" checkpoint, validates, and keeps the best model by
    ``val_metrics['loss']`` (lower is better).

    Args:
        args: parsed command-line namespace; fields used include optim, lrs,
            lr, factor, patience, min_lr, t_max, batch_size, val_batch_size,
            no_first_val, val, start_epoch, num_epochs, iter_val, always_save.
    """
    print('start training...')
    model, model_file = create_model(args)
    train_loader, val_loader = get_train_val_loaders(
        batch_size=args.batch_size, val_batch_size=args.val_batch_size)
    # NOTE(review): this immediately overwrites the train_loader returned
    # above — only val_loader from get_train_val_loaders is kept. Confirm
    # the frame-level loader replacement is intentional.
    train_loader = get_frame_train_loader(batch_size=args.batch_size)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1",verbosity=0)

    # Optimizer choice: 'Adam' / 'RAdam' / anything else falls back to SGD.
    if args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=0.0001)
    elif args.optim == 'RAdam':
        optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=0.0001)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=0.0001)

    # Scheduler choice: 'plateau' steps on the monitored metric, otherwise
    # cosine annealing steps unconditionally.
    if args.lrs == 'plateau':
        lr_scheduler = ReduceLROnPlateau(optimizer,
                                         mode='min',
                                         factor=args.factor,
                                         patience=args.patience,
                                         min_lr=args.min_lr)
    else:
        lr_scheduler = CosineAnnealingLR(optimizer,
                                         args.t_max,
                                         eta_min=args.min_lr)

    # Move to GPU after optimizer creation; .cuda() moves parameters
    # in place, so the optimizer's parameter references stay valid.
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        # DataParallel wraps the module, so preserve the custom .name
        # attribute on the wrapper for downstream code.
        model_name = model.name
        model = DataParallel(model)
        model.name = model_name

    #model=model.train()

    # best_f2 actually tracks the *loss* (best_key = 'loss'); the name is
    # historical. Initialized high so the first validation always improves it.
    best_f2 = 99999.
    best_key = 'loss'

    print(
        'epoch |    lr     |       %        |  loss  |  avg   |  loss  |  0.01  |  0.20  |  0.50  |  best  | time |  save |'
    )

    # Optional validation pass before any training, to seed the best metric.
    if not args.no_first_val:
        val_metrics = validate(args, model, val_loader)
        print(
            'val   |           |                |        |        | {:.4f} | {:.4f} | {:.4f} | {:.4f} | {:.4f} |       |        |'
            .format(val_metrics['loss'], val_metrics['f2_th_0.01'],
                    val_metrics['f2_th_0.20'], val_metrics['f2_th_0.50'],
                    val_metrics[best_key]))

        best_f2 = val_metrics[best_key]

    # Validation-only mode: stop after the initial evaluation.
    if args.val:
        return

    model.train()

    if args.lrs == 'plateau':
        lr_scheduler.step(best_f2)
    else:
        lr_scheduler.step()

    train_iter = 0

    for epoch in range(args.start_epoch, args.num_epochs):
        #train_loader, val_loader = get_train_val_loaders(batch_size=args.batch_size, val_batch_size=args.val_batch_size, val_num=args.val_num)

        train_loss = 0

        current_lr = get_lrs(optimizer)
        bg = time.time()
        for batch_idx, data in enumerate(train_loader):
            train_iter += 1
            # Loader layout differs by mode: segment loaders yield exactly
            # (rgb, audio, labels); otherwise the tuple carries extra fields
            # and the tensors sit at indices 0, 2 and 4.
            if train_loader.seg:
                rgb, audio, labels = [x.cuda() for x in data]
            else:
                rgb, audio, labels = data[0].cuda(), data[2].cuda(
                ), data[4].cuda()

            output = model(rgb, audio)

            # NOTE(review): `criterion` is not defined in this function —
            # it is resolved as a module-level global; verify it is set
            # before train() is called.
            loss = criterion(output, labels)
            batch_size = rgb.size(0)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            #with amp.scale_loss(loss, optimizer) as scaled_loss:
            #    scaled_loss.backward()

            # Running average of the per-batch loss for the progress line.
            train_loss += loss.item()
            print('\r {:4d} | {:.7f} | {:06d}/{} | {:.4f} | {:.4f} |'.format(
                epoch, float(current_lr[0]), args.batch_size * (batch_idx + 1),
                train_loader.num, loss.item(), train_loss / (batch_idx + 1)),
                  end='')

            # Periodic checkpoint + validation every args.iter_val iterations.
            if train_iter > 0 and train_iter % args.iter_val == 0:
                # Always refresh the "latest" snapshot (unwrap DataParallel
                # so the saved state_dict has clean key names).
                if isinstance(model, DataParallel):
                    torch.save(model.module.state_dict(),
                               model_file + '_latest')
                else:
                    torch.save(model.state_dict(), model_file + '_latest')

                val_metrics = validate(args, model, val_loader)

                _save_ckp = ''
                # Keep the best model by validation loss (or always, if asked).
                if args.always_save or val_metrics[best_key] < best_f2:
                    best_f2 = val_metrics[best_key]
                    if isinstance(model, DataParallel):
                        torch.save(model.module.state_dict(), model_file)
                    else:
                        torch.save(model.state_dict(), model_file)
                    _save_ckp = '*'
                print(
                    ' {:.4f} | {:.4f} | {:.4f} | {:.4f} | {:.4f} | {:.2f} |  {:4s} |'
                    .format(val_metrics['loss'], val_metrics['f2_th_0.01'],
                            val_metrics['f2_th_0.20'],
                            val_metrics['f2_th_0.50'], best_f2,
                            (time.time() - bg) / 60, _save_ckp))

                # validate() leaves the model in eval mode; switch back and
                # step the scheduler once per validation round.
                model.train()
                if args.lrs == 'plateau':
                    lr_scheduler.step(best_f2)
                else:
                    lr_scheduler.step()
                current_lr = get_lrs(optimizer)
Example #7
0
class DDPG(nn.Module):
    """Deterministic actor-critic agent with target networks.

    NOTE(review): despite the name, the update uses clipped noise on the
    target policy's actions (``policy_noise`` / ``noise_clip``), a technique
    from TD3 — confirm which algorithm variant this is meant to be.

    The critic loss is a weighted sum of the standard TD error term
    (``td_error_weight``) and an optional gradient-penalty term on the TD
    error w.r.t. the actions (``tdg_error_weight``).
    """
    def __init__(
        self,
        d_state,
        d_action,
        device,
        gamma,
        tau,
        policy_lr,
        value_lr,
        value_loss,
        value_n_layers,
        value_n_units,
        value_activation,
        policy_n_layers,
        policy_n_units,
        policy_activation,
        grad_clip,
        policy_noise=0.2,
        noise_clip=0.5,
        expl_noise=0.1,
        tdg_error_weight=0,
        td_error_weight=1,
    ):
        """Build actor/critic networks, their frozen target copies, and optimizers.

        Args:
            d_state / d_action: state and action dimensionalities.
            device: torch device the networks live on.
            gamma: discount factor (stored as ``self.discount``).
            tau: Polyak averaging coefficient for target-network updates.
            policy_lr / value_lr: RAdam learning rates for actor / critic.
            value_loss: 'huber' or 'mse' — loss form used in the critic update.
            *_n_layers / *_n_units / *_activation: network architecture knobs.
            grad_clip: per-element gradient value clip applied before each step.
            policy_noise / noise_clip: target-policy smoothing noise scale/clip.
            expl_noise: stddev of Gaussian exploration noise in get_action().
            tdg_error_weight / td_error_weight: weights of the gradient-penalty
                and standard TD terms in the critic loss.
        """
        super().__init__()

        # Actor and its target start identical (deepcopy).
        self.actor = Actor(d_state, d_action, policy_n_layers, policy_n_units,
                           policy_activation).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = RAdam(self.actor.parameters(), lr=policy_lr)

        # Critic and its target start identical (deepcopy).
        self.critic = ActionValueFunction(d_state, d_action, value_n_layers,
                                          value_n_units,
                                          value_activation).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = RAdam(self.critic.parameters(), lr=value_lr)

        self.discount = gamma
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.expl_noise = expl_noise
        # Optional state normalizer; installed later via setup_normalizer().
        self.normalizer = None
        self.value_loss = value_loss
        self.grad_clip = grad_clip
        self.device = device

        self.tdg_error_weight = tdg_error_weight
        self.td_error_weight = td_error_weight
        self.step_counter = 0

    def setup_normalizer(self, normalizer):
        """Install a (deep-copied) state normalizer used by all methods below."""
        self.normalizer = copy.deepcopy(normalizer)

    def get_action(self, states, deterministic=False):
        """Return actions in [-1, 1]; adds Gaussian exploration noise unless deterministic."""
        states = states.to(self.device)
        if self.normalizer is not None:
            states = self.normalizer.normalize_states(states)
        actions = self.actor(states)
        if not deterministic:
            actions = actions + torch.randn_like(actions) * self.expl_noise
        return actions.clamp(-1, +1)

    def get_action_with_logp(self, states):
        """Return deterministic actions and inf placeholders for log-probs.

        The policy is deterministic, so there is no meaningful log-probability;
        inf is returned as a sentinel that must not be consumed downstream.
        """
        states = states.to(self.device)
        if self.normalizer is not None:
            states = self.normalizer.normalize_states(states)
        a = self.actor(states)
        return a, torch.ones(
            a.shape[0], device=a.device) * np.inf  # inf: should not be used

    def get_action_value(self, states, actions):
        """Return the critic's first Q-estimate for (states, actions), without gradients."""
        if self.normalizer is not None:
            states = self.normalizer.normalize_states(states)
        with torch.no_grad():
            states = states.to(self.device)
            actions = actions.to(self.device)
            return self.critic(states, actions)[0]  # just q1

    def update(self, states, actions, logps, rewards, next_states, masks):
        """Run one critic + actor update and Polyak-average the targets.

        Args:
            states, actions, rewards, next_states: transition batch tensors.
            logps: unused here (kept for a common agent interface).
            masks: multiplied into the bootstrap term — presumably 1 for
                non-terminal transitions; TODO confirm against the caller.

        Returns:
            Tuple of (sample raw next action scalar, weighted TD loss,
            weighted gradient loss, actor loss), all as Python floats except
            noted.
        """
        if self.normalizer is not None:
            states = self.normalizer.normalize_states(states)
            next_states = self.normalizer.normalize_states(next_states)
        self.step_counter += 1

        # Select action according to policy and add clipped noise
        noise = (torch.randn_like(actions) * self.policy_noise).clamp(
            -self.noise_clip, self.noise_clip)
        raw_next_actions = self.actor_target(next_states)
        next_actions = (raw_next_actions + noise).clamp(-1, 1)

        # Compute the target Q value
        next_Q = self.critic_target(next_states, next_actions)
        q_target = rewards.unsqueeze(
            1) + self.discount * masks.float().unsqueeze(1) * next_Q
        zero_targets = torch.zeros_like(q_target, device=self.device)

        q = self.critic(states, actions)  # Q(s,a)
        q_td_error = q_target - q
        # Start all three loss accumulators at zero tensors so the returned
        # .item() values are well-defined even when a weight is 0.
        critic_loss, standard_loss, gradient_loss = torch.tensor(
            0, device=self.device), torch.tensor(
                0, device=self.device), torch.tensor(0, device=self.device)
        if self.td_error_weight > 0:
            # Standard TD term: regress the TD error toward zero.
            if self.value_loss == 'huber':
                standard_loss = 0.5 * F.smooth_l1_loss(q_td_error,
                                                       zero_targets)
            elif self.value_loss == 'mse':
                standard_loss = 0.5 * F.mse_loss(q_td_error, zero_targets)
            critic_loss = critic_loss + self.td_error_weight * standard_loss
        if self.tdg_error_weight > 0:
            # Gradient-penalty term: norm of d(TD error)/d(actions), pushed
            # toward zero. create_graph=True so this term is differentiable
            # for the critic update.
            gradients_error_norms = torch.autograd.grad(
                outputs=q_td_error,
                inputs=actions,
                grad_outputs=torch.ones(q_td_error.size(), device=self.device),
                retain_graph=True,
                create_graph=True,
                only_inputs=True)[0].flatten(start_dim=1).norm(dim=1,
                                                               keepdim=True)
            if self.value_loss == 'huber':
                gradient_loss = 0.5 * F.smooth_l1_loss(gradients_error_norms,
                                                       zero_targets)
            elif self.value_loss == 'mse':
                gradient_loss = 0.5 * F.mse_loss(gradients_error_norms,
                                                 zero_targets)
            critic_loss = critic_loss + self.tdg_error_weight * gradient_loss

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        # retain_graph=True: the actor update below reuses parts of the graph.
        critic_loss.backward(retain_graph=True)
        torch.nn.utils.clip_grad_value_(self.critic.parameters(),
                                        self.grad_clip)
        self.critic_optimizer.step()

        # Compute actor loss
        q = self.critic(states, self.actor(states))  # Q(s,pi(s))
        actor_loss = -q.mean()

        # Optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_value_(self.actor.parameters(),
                                        self.grad_clip)
        self.actor_optimizer.step()

        # Update the frozen target models
        for param, target_param in zip(self.critic.parameters(),
                                       self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data +
                                    (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.actor.parameters(),
                                       self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data +
                                    (1 - self.tau) * target_param.data)

        return raw_next_actions[
            0, 0].item(), self.td_error_weight * standard_loss.item(
            ), self.tdg_error_weight * gradient_loss.item(), actor_loss.item()

    @staticmethod
    def catastrophic_divergence(q_loss, pi_loss):
        """Heuristic divergence check: True when either loss has blown up."""
        return q_loss > 1e2 or (pi_loss is not None and abs(pi_loss) > 1e5)
Example #8
0
class Main(FlyAI):
    '''
    FlyAI project entry point; projects must subclass FlyAI or the online
    runner fails. Wraps model creation, data download and a 5-fold
    cross-validated training loop with mixup/cutmix/fmix augmentation options.
    '''
    def __init__(self, model_name):
        """Build the model, loss, optimizer and warm-up scheduler.

        Args:
            model_name: architecture name understood by get_net().
        """
        self.num_classes = 2
        # create model
        self.model_name = model_name
        self.model = get_net(model_name, self.num_classes)
        if use_gpu:
            self.model.to(DEVICE)
        # Hyper-parameter setup
        # self.criteration = LSRCrossEntropyLossV2(lb_smooth=0.2, lb_ignore=255)
        self.criteration = HybridCappaLoss()
        self.optimizer = RAdam(params=self.model.parameters(),
                               lr=0.003,
                               weight_decay=0.0001)
        # Cyclic cosine restarts at epochs 5, 35, 65, 95, 125.
        milestones = [5 + x * 30 for x in range(5)]
        print(f'milestones:{milestones}')
        scheduler_c = CyclicCosAnnealingLR(self.optimizer,
                                           milestones=milestones,
                                           eta_min=5e-5)
        # # scheduler_r = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.2, patience=4, verbose=True)
        # 5-epoch linear warm-up to lr=0.003, then hand off to the cyclic scheduler.
        self.scheduler = LearningRateWarmUP(optimizer=self.optimizer,
                                            target_iteration=5,
                                            target_lr=0.003,
                                            after_scheduler=scheduler_c)
        # Mutually intended augmentation switches; only fmix is enabled here.
        self.mix_up = False
        if self.mix_up:
            print("using mix_up")
        self.cutMix = False
        if self.cutMix:
            print("using cutMix")
        self.fmix = True
        if self.fmix:
            print("using fmix")

    def download_data(self):
        """Download the training data for this DataID via the FlyAI helper."""
        data_helper = DataHelper()
        data_helper.download_from_ids(DataID)

    def train_one_epoch(self, train_loader, val_loader):
        """Train for one epoch and evaluate on the validation loader.

        Returns:
            (train_loss, train_acc, val_loss, val_acc) where the losses are
            the summed batch losses divided by the dataset size (per-sample
            scale differs from a plain batch mean — kept for consistency
            with the original logging).
        """
        self.model.train()
        train_loss_sum, train_acc_sum = 0.0, 0.0
        for img, label in train_loader:
            # if len(label) <= 1:
            #     continue
            img, label = img.to(DEVICE), label.to(DEVICE)
            width, height = img.size(-1), img.size(-2)
            self.optimizer.zero_grad()
            if self.mix_up:
                # mixup: convex-combine pairs of images and both labels.
                img, labels_a, labels_b, lam = mixup_data(img,
                                                          label,
                                                          alpha=0.2)
                output = self.model(img)
                loss = mixup_criterion(self.criteration, output, labels_a,
                                       labels_b, lam)
            elif self.cutMix:
                # cutmix: paste a patch from a shuffled image; mix the losses.
                img, targets = cutmix(img, label)
                target_a, target_b, lam = targets
                output = self.model(img)
                loss = self.criteration(output,
                                        target_a) * lam + self.criteration(
                                            output, target_b) * (1. - lam)
            elif self.fmix:
                # fmix: Fourier-mask mixing; mix the losses with lam.
                data, target = fmix(img,
                                    label,
                                    alpha=1.,
                                    decay_power=3.,
                                    shape=(width, height))
                targets, shuffled_targets, lam = target
                output = self.model(data)
                loss = self.criteration(
                    output, targets) * lam + self.criteration(
                        output, shuffled_targets) * (1 - lam)
            else:
                output = self.model(img)
                loss = self.criteration(output, label)
            loss.backward()
            # NOTE: accuracy is computed against the original labels even
            # for mixed batches, so train_acc is only approximate there.
            _, preds = torch.max(output.data, 1)
            correct = (preds == label).sum().item()
            train_acc_sum += correct

            train_loss_sum += loss.item()
            self.optimizer.step()

        train_loss = train_loss_sum / len(train_loader.dataset)
        train_acc = train_acc_sum / len(train_loader.dataset)

        val_acc_sum = 0.0
        valid_loss_sum = 0
        self.model.eval()
        # NOTE(review): no torch.no_grad() around validation — gradients are
        # still tracked; confirm whether that is acceptable memory-wise.
        for val_img, val_label in val_loader:
            # if len(val_label) <= 1:
            #     continue
            val_img, val_label = val_img.to(DEVICE), val_label.to(DEVICE)
            val_output = self.model(val_img)
            _, preds = torch.max(val_output.data, 1)
            correct = (preds == val_label).sum().item()
            val_acc_sum += correct

            loss = self.criteration(val_output, val_label)
            valid_loss_sum += loss.item()

        val_acc = val_acc_sum / len(val_loader.dataset)
        val_loss = valid_loss_sum / len(val_loader.dataset)
        return train_loss, train_acc, val_loss, val_acc

    def train(self):
        '''
        Train the model (required FlyAI entry point): 5-fold CV over
        train.csv, re-initializing model/optimizer per fold and saving the
        best-accuracy checkpoint of each fold.
        :return:
        '''
        df = pd.read_csv(os.path.join(DATA_PATH, DataID, 'train.csv'))

        # FIX: random_state must not be set when shuffle=False — scikit-learn
        # >= 0.24 raises ValueError for that combination (and it had no
        # effect before). Splits are identical to the original behavior.
        kf = KFold(n_splits=5, shuffle=False)
        for fold, (train_idx, val_idx) in enumerate(kf.split(df)):
            # # abandon cross validation
            # if fold > 0:
            #     break
            # Re-run __init__ to get a fresh model/optimizer/scheduler per fold.
            self.__init__(self.model_name)
            print(
                f'fold:{fold+1}...', 'train_size: %d, val_size: %d' %
                (len(train_idx), len(val_idx)))

            # generate dataloder
            train_data = ImageData(df, train_idx, mode='train')
            val_data = ImageData(df, val_idx, mode='valid')
            train_loader = DataLoader(
                train_data,
                batch_size=args.BATCH,
                shuffle=True,
                # drop_last=True
            )
            val_loader = DataLoader(val_data,
                                    batch_size=args.BATCH,
                                    shuffle=False,
                                    drop_last=True)

            # Track the best validation accuracy within this fold.
            max_correct = 0
            for epoch in range(args.EPOCHS):
                self.scheduler.step(epoch)
                train_loss, train_acc, val_loss, val_acc = self.train_one_epoch(
                    train_loader, val_loader)
                start = time.strftime("%H:%M:%S")
                print(f'fold:{fold + 1}',
                      f"epoch:{epoch + 1}/{args.EPOCHS} | ⏰: {start}   ",
                      f"Training Loss: {train_loss:.6f}.. ",
                      f"Training Acc:  {train_acc:.6f}.. ",
                      f"validation Acc: {val_acc:.6f}.. ")

                train_log(train_loss=train_loss,
                          train_acc=train_acc,
                          val_loss=val_loss,
                          val_acc=val_acc)

                if val_acc > max_correct:
                    max_correct = val_acc
                    torch.save(
                        self.model, MODEL_PATH + '/' +
                        f"{self.model_name}_best_fold{fold+1}.pth")
                    # torch.save(self.model, MODEL_PATH + '/' + "best.pth")
                    print('find optimal model')
Example #9
0
def train_ccblock(model_options):
    """Train the criss-cross attention block (RCCAModule) with a quantization head.

    Two optimizers run on an interleaved schedule:
      * ``optimizer2`` updates the quantization module every batch on a
        weighted distortion loss (soft + hard distortion + joint-center term);
      * ``optimizer`` updates the CC block only on epochs where
        ``l % neighbor_freq == 0``, minimizing the squared difference between
        the embedding similarity (inner products / 1024) and a precomputed
        neighborhood-similarity matrix.

    Both models are checkpointed every ``model_options.save_freq`` epochs and
    once more after the final epoch.

    Args:
        model_options: configuration namespace; fields used here include
            ``trainset_path``, ``trainset_num``, ``dim``, ``model_save_path``,
            ``params_filename``, ``Qparams_filename``, ``reload_params``,
            ``records_save_path``, ``train_loss_filename``,
            ``train_batch_size``, ``default_dtype``, ``disp_freq``,
            ``save_freq``.
    """
    # Resolve the list of training-set shard paths (formatted 1..N, or a
    # single path when there is only one shard).
    if model_options.trainset_num > 1:
        train_file_paths = [
            model_options.trainset_path.format(i)
            for i in range(1, model_options.trainset_num + 1)
        ]
    else:
        train_file_paths = [model_options.trainset_path]

    # load datasets
    print(train_file_paths)
    label_paths = "/home/langruimin/BLSTM_pytorch/data/fcv/fcv_train_labels.mat"
    videoset = VideoDataset(train_file_paths, label_paths)
    print(len(videoset))

    # create models: CC attention block and the product-quantization head.
    model = RCCAModule(1, 1, recurrence=2)

    model_quan = Quantization(12, 1024, model_options.dim)

    params_path = os.path.join(model_options.model_save_path,
                               model_options.params_filename)
    params_path_Q = os.path.join(model_options.model_save_path,
                                 model_options.Qparams_filename)
    if model_options.reload_params:
        print('Loading model params...')
        model.load_state_dict(torch.load(params_path))
        print('Done.')

    model = model.cuda()
    model_quan = model_quan.cuda()
    # Separate optimizers: CC block at lr 1e-4, quantizer at lr 1e-3.
    optimizer = RAdam(model.parameters(),
                      lr=1e-4,
                      betas=(0.9, 0.999),
                      weight_decay=1e-4)
    optimizer2 = RAdam(model_quan.parameters(),
                       lr=1e-3,
                       betas=(0.9, 0.999),
                       weight_decay=1e-4)

    lr_C = ''
    lr_Q = ''
    # NOTE(review): `milestones` stays empty and lr_schduler_Q.step() is never
    # called below, so the quantizer lr is effectively constant.
    milestones = []
    # lr_schduler_C = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
    lr_schduler_Q = torch.optim.lr_scheduler.MultiStepLR(optimizer2,
                                                         milestones,
                                                         gamma=0.6,
                                                         last_epoch=-1)

    # Load the precomputed pairwise neighborhood-similarity matrix (stored as
    # a dense 0/1 matrix; kept on CPU as bytes and moved per-batch below).
    print("+++++++++loading similarity+++++++++")
    with open(
            "/home/langruimin/BLSTM_pytorch/data/fcv/SimilarityInfo/Sim_K1_10_K2_5_fcv.pkl",
            "rb") as f:
        similarity = pkl.load(f)
    similarity = torch.ByteTensor(similarity.astype(np.uint8))
    print("++++++++++similarity loaded+++++++")

    batch_idx = 1
    train_loss_rec = open(
        os.path.join(model_options.records_save_path,
                     model_options.train_loss_filename), 'w')
    error_ = 0.
    loss_ = 0.
    num = 0
    neighbor = True
    neighbor_freq = 2
    print("##########start train############")
    trainloader = torch.utils.data.DataLoader(
        videoset,
        batch_size=model_options.train_batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True)
    model.train()
    model_quan.train()

    neighbor_loss = 0.0
    for l in range(60):
        # lr_schduler_C.step(l)
        # lr_schduler_Q.step(l)

        if neighbor:
            # training
            for i, (data, index, _, _) in enumerate(trainloader):
                data = data.to(model_options.default_dtype)
                data = data.unsqueeze(1)  # add a channel axis for the CC block
                data = data.cuda()

                # cc_block embedding, squashed to (-1, 1)
                output_ccblock_mean = torch.tanh(model(data))

                # quantization block: returns hard/soft codes plus the
                # distortion terms used to build the quantizer loss.
                Qhard, Qsoft, SoftDistortion, HardDistortion, JointCenter, error, _ = model_quan(
                    output_ccblock_mean)
                Q_loss = 0.1 * SoftDistortion + HardDistortion + 0.1 * JointCenter

                # retain_graph so the CC-block graph survives for the
                # neighbor-loss backward pass below.
                optimizer2.zero_grad()
                Q_loss.backward(retain_graph=True)
                optimizer2.step()

                if l % neighbor_freq == 0:
                    # Neighborhood loss: match batch-wise embedding inner
                    # products to the precomputed similarity sub-matrix.
                    similarity_select = torch.index_select(
                        similarity, 0, index)
                    similarity_select = torch.index_select(
                        similarity_select, 1, index).float().cuda()
                    neighbor_loss = torch.sum(
                        (torch.mm(output_ccblock_mean,
                                  output_ccblock_mean.transpose(0, 1)) / 1024 -
                         similarity_select).pow(2))

                    optimizer.zero_grad()
                    neighbor_loss.backward()
                    optimizer.step()

                error_ += error.item()
                loss_ += neighbor_loss.item()
                num += 1
                if batch_idx % model_options.disp_freq == 0:
                    info = "epoch{0} Batch {1} loss:{2:.3f}  distortion:{3:.3f} " \
                        .format(l, batch_idx, loss_/ num, error_ / num)
                    print(info)
                    train_loss_rec.write(info + '\n')

                batch_idx += 1
            batch_idx = 0
            error_ = 0.
            loss_ = 0.
            num = 0

        if (l + 1) % model_options.save_freq == 0:
            print('epoch: ', l, 'New best model. Saving model ...')
            torch.save(model.state_dict(), params_path)
            torch.save(model_quan.state_dict(), params_path_Q)

            for param_group in optimizer.param_groups:
                lr_C = param_group['lr']
            for param_group in optimizer2.param_groups:
                lr_Q = param_group['lr']
            record_inf = "saved model at epoch {0} lr_C:{1} lr_Q:{2}".format(
                l, lr_C, lr_Q)
            train_loss_rec.write(record_inf + '\n')
        print("##########epoch done##########")

    print('train done. Saving model ...')
    torch.save(model.state_dict(), params_path)
    torch.save(model_quan.state_dict(), params_path_Q)
    # Close the loss record file (was previously leaked).
    train_loss_rec.close()
    print("##########train done##########")
Пример #10
0
class TD3(object):
    """Twin Delayed Deep Deterministic policy gradient (TD3) agent.

    Keeps an actor/critic pair plus slowly-tracking target copies. The twin
    critic's minimum forms the Bellman target (clipped double-Q); the actor
    and both target networks are refreshed only every ``policy_freq`` critic
    steps, with target-policy smoothing via clipped Gaussian noise.
    """

    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = RAdam(self.actor.parameters())

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = RAdam(self.critic.parameters())

        self.max_action = max_action

    def select_action(self, state):
        """Return the greedy policy action for one state as a flat numpy array."""
        obs = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(obs).cpu().data.numpy().flatten()

    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        """Run ``iterations`` TD3 optimization steps on replay-buffer samples."""
        for step in range(iterations):
            # Draw a transition batch: states, next states, actions,
            # rewards, done flags.
            x, y, u, r, d = replay_buffer.sample(batch_size)
            obs = torch.FloatTensor(x).to(device)
            act = torch.FloatTensor(u).to(device)
            next_obs = torch.FloatTensor(y).to(device)
            not_done = torch.FloatTensor(1 - d).to(device)
            rew = torch.FloatTensor(r).to(device)

            # Target-policy smoothing: perturb the target action with
            # clipped Gaussian noise, then clamp to the action bounds.
            noise = torch.FloatTensor(u).data.normal_(0, policy_noise).to(device)
            noise = noise.clamp(-noise_clip, noise_clip)
            next_act = (self.actor_target(next_obs) + noise).clamp(-self.max_action, self.max_action)

            # Clipped double-Q Bellman target (min of the twin target critics).
            target_q1, target_q2 = self.critic_target(next_obs, next_act)
            target_q = torch.min(target_q1, target_q2)
            target_q = rew + (not_done * discount * target_q).detach()

            # Regress both live critics toward the shared target.
            current_q1, current_q2 = self.critic(obs, act)
            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Delayed policy update: actor and targets move less often.
            if step % policy_freq == 0:
                actor_loss = -self.critic.Q1(obs, self.actor(obs)).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Polyak-average both frozen targets toward the live nets.
                self._soft_update(self.critic, self.critic_target, tau)
                self._soft_update(self.actor, self.actor_target, tau)

    @staticmethod
    def _soft_update(net, target_net, tau):
        """In-place Polyak update: target <- tau * net + (1 - tau) * target."""
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def save(self, filename, directory):
        """Write actor and critic weights under ``directory``."""
        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename, directory):
        """Restore actor and critic weights saved by :meth:`save`."""
        self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
Пример #11
0
class Trainer(object):
    """Fits a single-channel U-Net segmentation model with RAdam + focal loss."""

    def __init__(self):
        torch.set_num_threads(4)
        # Training hyper-parameters.
        self.n_epochs = 10
        self.batch_size = 1
        self.patch_size = 384
        self.is_augment = False
        self.cuda = torch.cuda.is_available()
        self.__build_model()

    def __build_model(self):
        # 1-in / 1-out U-Net with a slim base width of 16 channels.
        net = UNet(1, 1, base=16)
        self.model = net.cuda() if self.cuda else net

    def __reshapetensor(self, tensor, itype='image'):
        """Collapse the leading (batch, slice) axes of *tensor* into one."""
        if itype == 'image':
            b, s, c, h, w = tensor.size()
            return tensor.view(b * s, c, h, w)
        b, s, h, w = tensor.size()
        return tensor.view(b * s, h, w)

    def __get_optimizer(self, **params):
        """Create the RAdam optimizer and an LR-on-plateau scheduler."""
        self.optimizer = RAdam(params=self.model.parameters(),
                               lr=1e-2,
                               weight_decay=1e-5)

        # Scheduler tracks a "max" metric; halves the LR after 10 stale epochs.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                              'max',
                                                              factor=0.5,
                                                              patience=10,
                                                              verbose=True,
                                                              min_lr=1e-5)

    def run(self, trainset, model_dir):
        """Train over *trainset* for n_epochs and save weights to model_dir/model.pth."""
        banner = '=' * 100
        print(banner)
        print('Trainning model')
        print(banner)
        if not os.path.exists(model_dir):
            os.mkdir(model_dir)

        model_path = os.path.join(model_dir, 'model.pth')

        # Focal loss copes with the strong foreground/background imbalance.
        loss_fn = FocalLoss2d()

        self.__get_optimizer()
        Loss = []
        F1 = []
        for epoch in range(self.n_epochs):

            for ith_batch, data in enumerate(trainset):
                if self.cuda:
                    images, labels = [d.cuda() for d in data]
                else:
                    images, labels = data
                images = self.__reshapetensor(images, itype='image')
                labels = self.__reshapetensor(labels, itype='label')

                # Standard optimize step on the focal loss.
                preds = self.model(images)
                loss = loss_fn(preds, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                Loss.append(loss.item())

                # Binarize sigmoid outputs at 0.5 to compute the F1 score.
                preds = torch.sigmoid(preds)
                preds[preds > 0.5] = 1
                preds[preds <= 0.5] = 0
                preds = preds.cpu().detach().numpy().flatten()
                labels = labels.cpu().detach().numpy().flatten()
                f1 = f1_score(labels, preds, average='binary')
                F1.append(f1)

                print('EPOCH : {}-----BATCH : {}-----LOSS : {}-----F1 : {}'.
                      format(epoch, ith_batch, loss.item(), f1))

        torch.save(self.model.state_dict(), model_path)

        return model_path
Пример #12
0
class face_learner(object):
    """Face-recognition trainer: ResNet backbone + ArcFace classification heads.

    Two operating modes, chosen by ``conf.meta``:
      * meta mode: one ArcFace head over all ``total_class`` identities, with a
        base optimizer and a separate meta optimizer (used MAML-style in
        ``train_maml``);
      * per-race mode: one ArcFace head and one optimizer per entry of the
        module-level ``races`` iterable (used by ``train_finetuning`` and the
        head-only training methods).

    NOTE(review): ``races`` is not defined in this class — presumably a
    module-level list of race/domain labels; verify against the surrounding
    file.
    """
    def __init__(self, conf):
        # conf: configuration namespace — paths, lr, milestones, batch size,
        # embedding size, and the `initial` (reload weights) / `meta` flags.
        print(conf)
        self.model = ResNet()
        self.model.cuda()
        if conf.initial:
            self.model.load_state_dict(torch.load("models/"+conf.model))
            print('Load model_ir_se101.pth')
        self.milestones = conf.milestones
        # Per-race train loaders and class counts.
        self.loader, self.class_num = get_train_loader(conf)
        # Hard-coded dataset statistics: total identities / total images.
        self.total_class = 16520
        self.data_num = 285356
        self.writer = SummaryWriter(conf.log_path)
        self.step = 0
        # Split params so BatchNorm weights skip weight decay.
        self.paras_only_bn, self.paras_wo_bn = separate_bn_paras(self.model)

        if conf.meta:
            # Single shared head across all classes, plus a meta optimizer.
            self.head = Arcface(embedding_size=conf.embedding_size, classnum=self.total_class)
            self.head.cuda()
            if conf.initial:
                self.head.load_state_dict(torch.load("models/head_op.pth"))
                print('Load head_op.pth')
            self.optimizer = RAdam([
                {'params': self.paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                {'params': self.paras_only_bn}
            ], lr=conf.lr)
            self.meta_optimizer = RAdam([
                {'params': self.paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                {'params': self.paras_only_bn}
            ], lr=conf.lr)
            self.head.train()
        else:
            # One head + optimizer per race; each head sizes to that race's
            # class count.
            self.head = dict()
            self.optimizer = dict()
            for race in races:
                self.head[race] = Arcface(embedding_size=conf.embedding_size, classnum=self.class_num[race])
                self.head[race].cuda()
                if conf.initial:
                    self.head[race].load_state_dict(torch.load("models/head_op_{}.pth".format(race)))
                    print('Load head_op_{}.pth'.format(race))
                self.optimizer[race] = RAdam([
                    {'params': self.paras_wo_bn + [self.head[race].kernel], 'weight_decay': 5e-4},
                    {'params': self.paras_only_bn}
                ], lr=conf.lr, betas=(0.5, 0.999))
                self.head[race].train()
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

        # Logging/eval cadence derived from the smallest per-race loader and
        # the total image count.
        self.board_loss_every = min(len(self.loader[race]) for race in races) // 10
        self.evaluate_every = self.data_num // 5
        self.save_every = self.data_num // 2
        # Validation tensors and same-identity labels, per race.
        self.eval, self.eval_issame = get_val_data(conf)

    def save_state(self, conf, accuracy, extra=None, model_only=False, race='All'):
        """Checkpoint the backbone (and, unless model_only, the active head)
        under models/ with a timestamp/accuracy/step-tagged filename."""
        save_path = 'models/'
        torch.save(
            self.model.state_dict(), save_path +
                                     'model_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(), accuracy, self.step,
                                                                                  extra, race))
        if not model_only:
            if conf.meta:
                torch.save(
                    self.head.state_dict(), save_path +
                                        'head_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(), accuracy, self.step,
                                                                                    extra, race))
                #torch.save(
                #    self.optimizer.state_dict(), save_path +
                #                             'optimizer_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(), accuracy,
                #                                                                              self.step, extra, race))
            else:
                torch.save(
                    self.head[race].state_dict(), save_path +
                                            'head_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(), accuracy,
                                                                                           self.step,
                                                                                           extra, race))
                #torch.save(
                #    self.optimizer[race].state_dict(), save_path +
                 #                                'optimizer_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(),
                #                                                                                     accuracy,
                #                                                                                     self.step, extra,
                #                                                                                     race))

    def load_state(self, conf, fixed_str, model_only=False):
        """Restore backbone (and optionally head + optimizer) from models/.

        NOTE(review): `fixed_str` is unused — paths come from conf; and the
        head/optimizer branch assumes meta mode (non-dict attributes).
        """
        save_path = 'models/'
        self.model.load_state_dict(torch.load(save_path + conf.model))
        if not model_only:
            self.head.load_state_dict(torch.load(save_path + conf.head))
            self.optimizer.load_state_dict(torch.load(save_path + conf.optim))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        """Log one validation result (accuracy, threshold, ROC image) to TensorBoard."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)

        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Embed all images in `carray` and score verification accuracy.

        With `tta`, embeddings of the image and its horizontal flip are summed
        and L2-normalized. Returns (mean accuracy, mean best threshold,
        ROC-curve image tensor). The final `evaluate(...)` call is the
        module-level verification function, not this method.
        """
        self.model.eval()
        idx = 0
        entry_num = carray.size()[0]
        embeddings = np.zeros([entry_num, conf.embedding_size])
        with torch.no_grad():
            # Full batches first...
            while idx + conf.batch_size <= entry_num:
                batch = carray[idx:idx + conf.batch_size]
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(fliped.cuda())
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu().detach().numpy()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(batch.cuda()).cpu().detach().numpy()
                idx += conf.batch_size
            # ...then the remainder, if any.
            if idx < entry_num:
                batch = carray[idx:]
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(fliped.cuda())
                    embeddings[idx:] = l2_norm(emb_batch).cpu().detach().numpy()
                else:
                    embeddings[idx:] = self.model(batch.cuda()).cpu().detach().numpy()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def train_finetuning(self, conf, epochs, race):
        """Fine-tune backbone + the `race`-specific head on that race's loader,
        with gradient clipping; checkpoints the model once per epoch."""
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            '''
            if e == self.milestones[0]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            if e == self.milestones[1]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            if e == self.milestones[2]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            '''
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                self.optimizer[race].zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head[race].parameters(), conf.max_grad_norm)
                self.optimizer[race].step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % (1 * len(self.loader[race])) == 0 and self.step != 0:
                    self.save_state(conf, 'None', race=race, model_only=True)

                self.step += 1

        self.save_state(conf, 'None', extra='final', race=race)
        torch.save(self.optimizer[race].state_dict(), 'models/optimizer_{}.pth'.format(race))

    def train_maml(self, conf, epochs):
        """MAML-style meta-training (meta mode only).

        Each step: snapshot weights, take a base-learner step on a batch from
        one race, evaluate the adapted model on a second race, restore the
        snapshot, then apply the meta gradient from the second loss.
        """
        self.model.train()
        running_loss = 0.
        loader_iter = dict()
        for race in races:
            loader_iter[race] = iter(self.loader[race])
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for i in tqdm(range(self.data_num // conf.batch_size)):
                # Pick two distinct races: one for the inner (base) step, one
                # for the outer (meta) step. Iterators are recycled on
                # exhaustion. NOTE(review): `.next()` is Python-2 style — in
                # Python 3 this relies on the iterator exposing a `next`
                # method; confirm against the runtime used.
                ra1, ra2 = random.sample(races, 2)
                try:
                    imgs1, labels1 = loader_iter[ra1].next()
                except StopIteration:
                    loader_iter[ra1] = iter(self.loader[ra1])
                    imgs1, labels1 = loader_iter[ra1].next()

                try:
                    imgs2, labels2 = loader_iter[ra2].next()
                except StopIteration:
                    loader_iter[ra2] = iter(self.loader[ra2])
                    imgs2, labels2 = loader_iter[ra2].next()

                ## save original weights to make the update
                weights_original_model = deepcopy(self.model.state_dict())
                weights_original_head = deepcopy(self.head.state_dict())

                # base learn: ordinary supervised step on race 1
                imgs1 = imgs1.cuda()
                labels1 = labels1.cuda()
                self.optimizer.zero_grad()
                embeddings1 = self.model(imgs1)
                thetas1 = self.head(embeddings1, labels1)
                loss1 = conf.ce_loss(thetas1, labels1)
                loss1.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.optimizer.step()

                # meta learn: forward on race 2 with adapted weights, restore
                # the snapshot, then step the meta optimizer on loss2.
                imgs2 = imgs2.cuda()
                labels2 = labels2.cuda()
                embeddings2 = self.model(imgs2)
                thetas2 = self.head(embeddings2, labels2)
                self.model.load_state_dict(weights_original_model)
                self.head.load_state_dict(weights_original_head)
                self.meta_optimizer.zero_grad()
                loss2 = conf.ce_loss(thetas2, labels2)
                loss2.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.meta_optimizer.step()

                running_loss += loss2.item()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    # Periodic per-race verification eval; restore train mode after.
                    for race in races:
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.eval[race], self.eval_issame[race])
                        self.board_val(race, accuracy, best_threshold, roc_curve_tensor)
                    self.model.train()

                if self.step % (self.data_num // conf.batch_size // 2) == 0 and self.step != 0:
                    self.save_state(conf, e)

                self.step += 1

        self.save_state(conf, epochs, extra='final')

    def train_meta_head(self, conf, epochs):
        """Train only the shared (meta-mode) head with SGD, cycling through
        every race's loader each epoch; backbone weights are not updated by
        the optimizer (only the head's params are registered)."""
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head.parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for race in races:
                for imgs, labels in tqdm(iter(self.loader[race])):
                    imgs = imgs.cuda()
                    labels = labels.cuda()
                    optimizer.zero_grad()
                    embeddings = self.model(imgs)
                    thetas = self.head(embeddings, labels)
                    loss = conf.ce_loss(thetas, labels)
                    loss.backward()
                    running_loss += loss.item()
                    optimizer.step()

                    if self.step % self.board_loss_every == 0 and self.step != 0:
                        loss_board = running_loss / self.board_loss_every
                        self.writer.add_scalar('train_loss', loss_board, self.step)
                        running_loss = 0.

                    self.step += 1

            torch.save(self.head.state_dict(), 'models/head_{}_meta_{}.pth'.format(get_time(), e))

    def train_race_head(self, conf, epochs, race):
        """Train only the `race`-specific head with SGD on that race's loader;
        saves the head once at the end."""
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head[race].parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                self.step += 1

        torch.save(self.head[race].state_dict(), 'models/head_{}_{}_{}.pth'.format(get_time(), race, epochs))

    def schedule_lr(self):
        """Divide both meta-mode optimizer learning rates by 10 (meta mode
        only — assumes self.optimizer/meta_optimizer are optimizers, not the
        per-race dict)."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        for params in self.meta_optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer, self.meta_optimizer)
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--waveforms_eval",
                        type=str,
                        help="directory or list of evaluation wav files")
    parser.add_argument("--feats",
                        required=True,
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats_eval",
                        required=True,
                        type=str,
                        help="directory or list of evaluation feat files")
    parser.add_argument("--stats",
                        required=True,
                        type=str,
                        help="directory or list of evaluation wav files")
    parser.add_argument("--expdir",
                        required=True,
                        type=str,
                        help="directory to save the model")
    # network structure setting
    parser.add_argument("--upsampling_factor",
                        default=120,
                        type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--hid_chn",
                        default=256,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--skip_chn",
                        default=256,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_depth",
                        default=3,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat",
                        default=2,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--kernel_size",
                        default=7,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--kernel_size_wave",
                        default=7,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--dilation_size_wave",
                        default=1,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--mcep_dim",
                        default=50,
                        type=int,
                        help="kernel size of dilated causal convolution")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument(
        "--batch_size",
        default=30,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count",
                        default=4000,
                        type=int,
                        help="number of training epochs")
    parser.add_argument("--do_prob",
                        default=0,
                        type=float,
                        help="dropout probability")
    parser.add_argument(
        "--batch_size_utt",
        default=5,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--batch_size_utt_eval",
        default=5,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--n_workers",
        default=2,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--n_quantize",
        default=256,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--bi_wave",
        default=True,
        type=strtobool,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--causal_conv_wave",
        default=False,
        type=strtobool,
        help="batch size (if set 0, utterance batch will be used)")
    # other setting
    parser.add_argument("--init",
                        default=False,
                        type=strtobool,
                        help="seed number")
    parser.add_argument("--pad_len",
                        default=3000,
                        type=int,
                        help="seed number")
    ##parser.add_argument("--save_interval_iter", default=5000,
    #parser.add_argument("--save_interval_iter", default=3000,
    #                    type=int, help="interval steps to logr")
    parser.add_argument("--save_interval_epoch",
                        default=10,
                        type=int,
                        help="interval steps to logr")
    parser.add_argument("--log_interval_steps",
                        default=50,
                        type=int,
                        help="interval steps to logr")
    parser.add_argument("--seed", default=1, type=int, help="seed number")
    parser.add_argument("--resume",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--pretrained",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--preconf",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--string_path",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--GPU_device",
                        default=None,
                        type=int,
                        help="selection of GPU device")
    parser.add_argument("--verbose", default=1, type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device)

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if str(device) == "cpu":
        raise ValueError('ERROR: Training by CPU is not acceptable.')

    torch.backends.cudnn.benchmark = True  #faster

    if args.pretrained is None:
        if 'mel' in args.string_path:
            mean_stats = torch.FloatTensor(read_hdf5(args.stats,
                                                     "/mean_melsp"))
            scale_stats = torch.FloatTensor(
                read_hdf5(args.stats, "/scale_melsp"))
            args.excit_dim = 0
        else:
            mean_stats = torch.FloatTensor(
                read_hdf5(args.stats, "/mean_feat_mceplf0cap"))
            scale_stats = torch.FloatTensor(
                read_hdf5(args.stats, "/scale_feat_mceplf0cap"))
            args.cap_dim = mean_stats.shape[0] - (args.mcep_dim + 3)
            args.excit_dim = 2 + 1 + args.cap_dim
    else:
        config = torch.load(args.preconf)
        args.excit_dim = config.excit_dim
        args.cap_dim = config.cap_dim

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model_waveform = DSWNV(n_aux=args.mcep_dim + args.excit_dim,
                           upsampling_factor=args.upsampling_factor,
                           hid_chn=args.hid_chn,
                           skip_chn=args.skip_chn,
                           kernel_size=args.kernel_size,
                           aux_kernel_size=args.kernel_size_wave,
                           aux_dilation_size=args.dilation_size_wave,
                           dilation_depth=args.dilation_depth,
                           dilation_repeat=args.dilation_repeat,
                           n_quantize=args.n_quantize,
                           do_prob=args.do_prob)
    logging.info(model_waveform)
    shift_rec_field = model_waveform.receptive_field
    logging.info(shift_rec_field)
    if shift_rec_field % args.upsampling_factor > 0:
        shift_rec_field_frm = shift_rec_field // args.upsampling_factor + 1
    else:
        shift_rec_field_frm = shift_rec_field // args.upsampling_factor
    shift_rec_field = shift_rec_field_frm * args.upsampling_factor
    logging.info(shift_rec_field)
    logging.info(shift_rec_field_frm)
    criterion_ce = torch.nn.CrossEntropyLoss(reduction='none')
    criterion_l1 = torch.nn.L1Loss(reduction='none')

    # send to gpu
    if torch.cuda.is_available():
        model_waveform.cuda()
        criterion_ce.cuda()
        criterion_l1.cuda()
        if args.pretrained is None:
            mean_stats = mean_stats.cuda()
            scale_stats = scale_stats.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model_waveform.train()

    if args.pretrained is None:
        model_waveform.scale_in.weight = torch.nn.Parameter(
            torch.unsqueeze(torch.diag(1.0 / scale_stats.data), 2))
        model_waveform.scale_in.bias = torch.nn.Parameter(-(mean_stats.data /
                                                            scale_stats.data))

    #if args.pretrained is not None:
    #    checkpoint = torch.load(args.pretrained)
    #    #model_waveform.remove_weight_norm()
    #    #model_waveform.load_state_dict(checkpoint["model"])
    #    model_waveform.load_state_dict(checkpoint["model_waveform"])
    #    epoch_idx = checkpoint["iterations"]
    #    logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
    #    epoch_idx = 0
    #    #model_waveform.apply_weight_norm()
    #    #torch.nn.utils.remove_weight_norm(model_waveform.scale_in)

    for param in model_waveform.parameters():
        param.requires_grad = True
    for param in model_waveform.scale_in.parameters():
        param.requires_grad = False

    parameters = filter(lambda p: p.requires_grad, model_waveform.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters (waveform): %.3f million' % parameters)

    module_list = list(model_waveform.conv_aux.parameters()) + list(
        model_waveform.upsampling.parameters())
    if model_waveform.wav_conv_flag:
        module_list += list(model_waveform.wav_conv.parameters())
    module_list += list(model_waveform.causal.parameters())
    module_list += list(model_waveform.in_x.parameters()) + list(
        model_waveform.dil_h.parameters())
    module_list += list(model_waveform.out_skip.parameters())
    module_list += list(model_waveform.out_1.parameters()) + list(
        model_waveform.out_2.parameters())

    optimizer = RAdam(module_list, lr=args.lr)
    #optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # resume
    if args.pretrained is not None and args.resume is None:
        checkpoint = torch.load(args.pretrained)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
        #    optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0
    elif args.resume is not None:
        #if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    #    epoch_idx = 2
    else:
        epoch_idx = 0

    def zero_wav_pad(x):
        return padding(x, args.pad_len * args.upsampling_factor,
                       value=0.0)  # noqa: E704

    def zero_feat_pad(x):
        return padding(x, args.pad_len, value=0.0)  # noqa: E704

    pad_wav_transform = transforms.Compose([zero_wav_pad])
    pad_feat_transform = transforms.Compose([zero_feat_pad])

    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(
            find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats):
        feat_list = [args.feats + "/" + filename for filename in filenames]
    elif os.path.isfile(args.feats):
        feat_list = read_txt(args.feats)
    else:
        logging.error("--feats should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(feat_list))
    dataset = FeatureDatasetNeuVoco(wav_list,
                                    feat_list,
                                    pad_wav_transform,
                                    pad_feat_transform,
                                    args.upsampling_factor,
                                    args.string_path,
                                    wav_transform=wav_transform)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size_utt,
                            shuffle=True,
                            num_workers=args.n_workers)
    #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=1)
    generator = train_generator(dataloader,
                                device,
                                args.batch_size,
                                args.upsampling_factor,
                                limit_count=None)
    #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=1, resume_c_idx=1426, max_c_idx=(len(feat_list)//args.batch_size_utt))
    #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=None, resume_c_idx=1426, max_c_idx=(len(feat_list)//args.batch_size_utt))

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames = sorted(
            find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [
            args.waveforms + "/" + filename for filename in filenames
        ]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats_eval):
        feat_list_eval = [
            args.feats_eval + "/" + filename for filename in filenames
        ]
    elif os.path.isfile(args.feats):
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--feats_eval should be directory or list.")
        sys.exit(1)
    assert len(wav_list_eval) == len(feat_list_eval)
    logging.info("number of evaluation data = %d." % len(feat_list_eval))
    dataset_eval = FeatureDatasetNeuVoco(wav_list_eval,
                                         feat_list_eval,
                                         pad_wav_transform,
                                         pad_feat_transform,
                                         args.upsampling_factor,
                                         args.string_path,
                                         wav_transform=wav_transform)
    dataloader_eval = DataLoader(dataset_eval,
                                 batch_size=args.batch_size_utt_eval,
                                 shuffle=False,
                                 num_workers=args.n_workers)
    ##generator_eval = eval_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=1)
    #generator_eval = eval_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=None)
    #generator_eval = train_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=1)
    generator_eval = train_generator(dataloader_eval,
                                     device,
                                     args.batch_size,
                                     args.upsampling_factor,
                                     limit_count=None)

    writer = SummaryWriter(args.expdir)
    total_train_loss = defaultdict(list)
    total_eval_loss = defaultdict(list)

    # train
    logging.info(args.string_path)
    total = 0
    iter_count = 0
    loss_ce = []
    loss_err = []
    min_eval_loss_err = 99999999.99
    min_eval_loss_err_std = 99999999.99
    min_eval_loss_ce = 99999999.99
    min_eval_loss_ce_std = 99999999.99
    iter_idx = 0
    min_idx = -1
    #min_eval_loss_ce = 1.575400
    #min_eval_loss_ce_std = 0.645726
    #iter_idx = 8098898
    #min_idx = 68 #resume70
    change_min_flag = False
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx + 1))
    logging.info("Training data")
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
            del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        if args.init:
            c_idx = -1
        if c_idx < 0:  # summarize epoch
            if not args.init:
                # save current epoch model
                numpy_random_state = np.random.get_state()
                torch_random_state = torch.get_rng_state()
                # report current epoch
                logging.info("(EPOCH:%d) average optimization loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; "\
                    "(%.3f min., %.3f sec / batch)" % (epoch_idx + 1, np.mean(loss_ce), np.std(loss_ce), \
                        np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count))
                logging.info("estimated time until max. epoch = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\
                "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            model_waveform.eval()
            for param in model_waveform.parameters():
                param.requires_grad = False
            pair_exist = False
            logging.info("Evaluation data")
            while True:
                with torch.no_grad():
                    start = time.time()
                    batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                        del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator_eval)
                    if c_idx < 0:
                        break

                    x_es = x_ss + x_bs
                    f_es = f_ss + f_bs
                    logging.info(
                        f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}'
                    )
                    if x_ss > 0:
                        if x_es <= max_slen:
                            batch_x_prev = batch_x[:, x_ss - shift_rec_field -
                                                   1:x_es - 1]
                            batch_feat = batch_feat[:, f_ss -
                                                    shift_rec_field_frm:f_es]
                            batch_x = batch_x[:, x_ss:x_es]
                        else:
                            batch_x_prev = batch_x[:, x_ss - shift_rec_field -
                                                   1:-1]
                            batch_feat = batch_feat[:, f_ss -
                                                    shift_rec_field_frm:]
                            batch_x = batch_x[:, x_ss:]
                    #    assert((batch_x_prev[:,shift_rec_field+1:] == batch_x[:,:-1]).all())
                    else:
                        batch_x_prev = F.pad(
                            batch_x[:, :x_es - 1],
                            (model_waveform.receptive_field + 1, 0),
                            "constant", args.n_quantize // 2)
                        batch_feat = batch_feat[:, :f_es]
                        batch_x = batch_x[:, :x_es]
                    #    assert((batch_x_prev[:,model_waveform.receptive_field+1:] == batch_x[:,:-1]).all())

                    if x_ss > 0:
                        batch_x_output = model_waveform(
                            batch_feat, batch_x_prev)[:, shift_rec_field:]
                    else:
                        batch_x_output = model_waveform(
                            batch_feat, batch_x_prev,
                            first=True)[:, model_waveform.receptive_field:]

                    # samples check
                    i = np.random.randint(0, batch_x_output.shape[0])
                    logging.info("%s" % (os.path.join(
                        os.path.basename(os.path.dirname(featfile[i])),
                        os.path.basename(featfile[i]))))
                    #check_samples = batch_x[i,5:10].long()
                    #logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples))
                    #logging.info(check_samples)

                    # handle short ending
                    batch_loss = 0
                    if len(idx_select) > 0:
                        logging.info('len_idx_select: ' + str(len(idx_select)))
                        batch_loss_ce = 0
                        batch_loss_err = 0
                        for j in range(len(idx_select)):
                            k = idx_select[j]
                            slens_utt = slens_acc[k]
                            logging.info('%s %d' % (featfile[k], slens_utt))
                            batch_x_output_k = batch_x_output[k, :slens_utt]
                            batch_x_k = batch_x[k, :slens_utt]
                            batch_loss_ce += torch.mean(
                                criterion_ce(batch_x_output_k, batch_x_k))
                            batch_loss_err += torch.mean(
                                torch.sum(
                                    100 * criterion_l1(
                                        F.softmax(batch_x_output_k, dim=-1),
                                        F.one_hot(batch_x_k,
                                                  num_classes=args.n_quantize).
                                        float()), -1))
                        batch_loss += batch_loss_ce
                        batch_loss_ce /= len(idx_select)
                        batch_loss_err /= len(idx_select)
                        total_eval_loss["eval/loss_ce"].append(
                            batch_loss_ce.item())
                        total_eval_loss["eval/loss_err"].append(
                            batch_loss_err.item())
                        loss_ce.append(batch_loss_ce.item())
                        loss_err.append(batch_loss_err.item())
                        if len(idx_select_full) > 0:
                            logging.info('len_idx_select_full: ' +
                                         str(len(idx_select_full)))
                            batch_x = torch.index_select(
                                batch_x, 0, idx_select_full)
                            batch_x_output = torch.index_select(
                                batch_x_output, 0, idx_select_full)
                        else:
                            logging.info(
                                "batch eval loss select %.3f %.3f (%.3f sec)" %
                                (batch_loss_ce.item(), batch_loss_err.item(),
                                 time.time() - start))
                            iter_count += 1
                            total += time.time() - start
                            continue

                    batch_loss_ce_ = torch.mean(
                        criterion_ce(
                            batch_x_output.reshape(-1, args.n_quantize),
                            batch_x.reshape(-1)).reshape(
                                batch_x_output.shape[0], -1), -1)
                    batch_loss_err_ = torch.mean(
                        torch.sum(
                            100 * criterion_l1(
                                F.softmax(batch_x_output, dim=-1),
                                F.one_hot(
                                    batch_x,
                                    num_classes=args.n_quantize).float()), -1),
                        -1)

                    batch_loss_ce = batch_loss_ce_.mean()
                    batch_loss_err = batch_loss_err_.mean()
                    total_eval_loss["eval/loss_ce"].append(
                        batch_loss_ce.item())
                    total_eval_loss["eval/loss_err"].append(
                        batch_loss_err.item())
                    loss_ce.append(batch_loss_ce.item())
                    loss_err.append(batch_loss_err.item())

                    logging.info("batch eval loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, \
                        x_ss, x_bs, f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            logging.info('sme')
            for key in total_eval_loss.keys():
                total_eval_loss[key] = np.mean(total_eval_loss[key])
                logging.info(
                    f"(Steps: {iter_idx}) {key} = {total_eval_loss[key]:.4f}.")
            write_to_tensorboard(writer, iter_idx, total_eval_loss)
            total_eval_loss = defaultdict(list)
            eval_loss_ce = np.mean(loss_ce)
            eval_loss_ce_std = np.std(loss_ce)
            eval_loss_err = np.mean(loss_err)
            eval_loss_err_std = np.std(loss_err)
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; (%.3f min., "\
                "%.3f sec / batch)" % (epoch_idx + 1, eval_loss_ce, eval_loss_ce_std, \
                    eval_loss_err, eval_loss_err_std, total / 60.0, total / iter_count))
            if (eval_loss_ce+eval_loss_ce_std) <= (min_eval_loss_ce+min_eval_loss_ce_std) \
                or (eval_loss_ce <= min_eval_loss_ce):
                min_eval_loss_ce = eval_loss_ce
                min_eval_loss_ce_std = eval_loss_ce_std
                min_eval_loss_err = eval_loss_err
                min_eval_loss_err_std = eval_loss_err_std
                min_idx = epoch_idx
                change_min_flag = True
            #else:
            #    epoch_min_flag = False
            if change_min_flag:
                logging.info("min_eval_loss = %.6f (+- %.6f) %.6f (+- %.6f) %% min_idx=%d" % (min_eval_loss_ce, \
                    min_eval_loss_ce_std, min_eval_loss_err, min_eval_loss_err_std, min_idx+1))
            #if ((epoch_idx + 1) % args.save_interval_epoch == 0) or (epoch_min_flag):
            #    logging.info('save epoch:%d' % (epoch_idx+1))
            #    save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state, torch_random_state, epoch_idx + 1)
            if args.init:
                exit()
            logging.info('save epoch:%d' % (epoch_idx + 1))
            save_checkpoint(args.expdir, model_waveform, optimizer,
                            numpy_random_state, torch_random_state,
                            epoch_idx + 1)
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            epoch_idx += 1
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model_waveform.train()
            for param in model_waveform.parameters():
                param.requires_grad = True
            for param in model_waveform.scale_in.parameters():
                param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx + 1))
                logging.info("Training data")
                batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                    del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx + 1, epoch_idx + 1))

            x_es = x_ss + x_bs
            f_es = f_ss + f_bs
            logging.info(
                f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}')
            if x_ss > 0:
                if x_es <= max_slen:
                    batch_x_prev = batch_x[:,
                                           x_ss - shift_rec_field - 1:x_es - 1]
                    batch_feat = batch_feat[:, f_ss - shift_rec_field_frm:f_es]
                    batch_x = batch_x[:, x_ss:x_es]
                else:
                    batch_x_prev = batch_x[:, x_ss - shift_rec_field - 1:-1]
                    batch_feat = batch_feat[:, f_ss - shift_rec_field_frm:]
                    batch_x = batch_x[:, x_ss:]
            #    assert((batch_x_prev[:,shift_rec_field+1:] == batch_x[:,:-1]).all())
            else:
                batch_x_prev = F.pad(batch_x[:, :x_es - 1],
                                     (model_waveform.receptive_field + 1, 0),
                                     "constant", args.n_quantize // 2)
                batch_feat = batch_feat[:, :f_es]
                batch_x = batch_x[:, :x_es]
            #    assert((batch_x_prev[:,model_waveform.receptive_field+1:] == batch_x[:,:-1]).all())

            if x_ss > 0:
                batch_x_output = model_waveform(batch_feat,
                                                batch_x_prev,
                                                do=True)[:, shift_rec_field:]
            else:
                batch_x_output = model_waveform(
                    batch_feat, batch_x_prev, first=True,
                    do=True)[:, model_waveform.receptive_field:]

            # samples check
            i = np.random.randint(0, batch_x_output.shape[0])
            logging.info(
                "%s" %
                (os.path.join(os.path.basename(os.path.dirname(featfile[i])),
                              os.path.basename(featfile[i]))))
            #with torch.no_grad():
            #    i = np.random.randint(0, batch_x_output.shape[0])
            #    logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),os.path.basename(featfile[i]))))
            #    check_samples = batch_x[i,5:10].long()
            #    logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples))
            #    logging.info(check_samples)

            # handle short ending
            batch_loss = 0
            if len(idx_select) > 0:
                logging.info('len_idx_select: ' + str(len(idx_select)))
                batch_loss_ce = 0
                batch_loss_err = 0
                for j in range(len(idx_select)):
                    k = idx_select[j]
                    slens_utt = slens_acc[k]
                    logging.info('%s %d' % (featfile[k], slens_utt))
                    batch_x_output_k = batch_x_output[k, :slens_utt]
                    batch_x_k = batch_x[k, :slens_utt]
                    batch_loss_ce += torch.mean(
                        criterion_ce(batch_x_output_k, batch_x_k))
                    batch_loss_err += torch.mean(
                        torch.sum(
                            100 * criterion_l1(
                                F.softmax(batch_x_output_k, dim=-1),
                                F.one_hot(
                                    batch_x_k,
                                    num_classes=args.n_quantize).float()), -1))
                batch_loss += batch_loss_ce
                batch_loss_ce /= len(idx_select)
                batch_loss_err /= len(idx_select)
                total_train_loss["train/loss_ce"].append(batch_loss_ce.item())
                total_train_loss["train/loss_err"].append(
                    batch_loss_err.item())
                loss_ce.append(batch_loss_ce.item())
                loss_err.append(batch_loss_err.item())
                if len(idx_select_full) > 0:
                    logging.info('len_idx_select_full: ' +
                                 str(len(idx_select_full)))
                    batch_x = torch.index_select(batch_x, 0, idx_select_full)
                    batch_x_output = torch.index_select(
                        batch_x_output, 0, idx_select_full)
                else:
                    optimizer.zero_grad()
                    batch_loss.backward()
                    torch.nn.utils.clip_grad_norm_(model_waveform.parameters(),
                                                   10)
                    optimizer.step()

                    logging.info("batch loss select %.3f %.3f (%.3f sec)" %
                                 (batch_loss_ce.item(), batch_loss_err.item(),
                                  time.time() - start))
                    iter_idx += 1
                    #if iter_idx % args.save_interval_iter == 0:
                    #    logging.info('save iter:%d' % (iter_idx))
                    #    save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx)
                    iter_count += 1
                    if iter_idx % args.log_interval_steps == 0:
                        logging.info('smt')
                        for key in total_train_loss.keys():
                            total_train_loss[key] = np.mean(
                                total_train_loss[key])
                            logging.info(
                                f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}."
                            )
                        write_to_tensorboard(writer, iter_idx,
                                             total_train_loss)
                        total_train_loss = defaultdict(list)
                    total += time.time() - start
                    continue

            # loss
            batch_loss_ce_ = torch.mean(
                criterion_ce(batch_x_output.reshape(-1, args.n_quantize),
                             batch_x.reshape(-1)).reshape(
                                 batch_x_output.shape[0], -1), -1)
            batch_loss_err_ = torch.mean(
                torch.sum(
                    100 * criterion_l1(
                        F.softmax(batch_x_output, dim=-1),
                        F.one_hot(batch_x,
                                  num_classes=args.n_quantize).float()), -1),
                -1)

            batch_loss_ce = batch_loss_ce_.mean()
            batch_loss_err = batch_loss_err_.mean()
            total_train_loss["train/loss_ce"].append(batch_loss_ce.item())
            total_train_loss["train/loss_err"].append(batch_loss_err.item())
            loss_ce.append(batch_loss_ce.item())
            loss_err.append(batch_loss_err.item())

            batch_loss += batch_loss_ce_.sum()

            optimizer.zero_grad()
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10)
            optimizer.step()

            logging.info("batch loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, x_ss, x_bs, \
                f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start))
            iter_idx += 1
            #if iter_idx % args.save_interval_iter == 0:
            #    logging.info('save iter:%d' % (iter_idx))
            #    save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx)
            iter_count += 1
            if iter_idx % args.log_interval_steps == 0:
                logging.info('smt')
                for key in total_train_loss.keys():
                    total_train_loss[key] = np.mean(total_train_loss[key])
                    logging.info(
                        f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}."
                    )
                write_to_tensorboard(writer, iter_idx, total_train_loss)
                total_train_loss = defaultdict(list)
            total += time.time() - start

    # save final model
    model_waveform.cpu()
    torch.save({"model_waveform": model_waveform.state_dict()},
               args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Пример #14
0
class Trainer:
    """Train and evaluate a reconstruction (autoencoder-style) model while
    logging per-layer saturation statistics.

    The model is optimized with an MSE loss between its output and its own
    input (targets are the inputs — see ``train_epoch``/``test``). Metrics are
    written to a CSV under ``logs_dir/<model.name>/<train_loader.name>`` and
    the model state dict is checkpointed next to it after every epoch.
    Saturation statistics are collected through ``CheckLayerSat``.
    """

    def __init__(self,
                 model,
                 train_loader,
                 test_loader,
                 epochs=200,
                 batch_size=60,
                 run_id=0,
                 logs_dir='logs',
                 device='cpu',
                 saturation_device=None,
                 optimizer='None',
                 plot=True,
                 compute_top_k=False,
                 data_prallel=False,
                 conv_method='mean'):
        """Build the trainer.

        Args:
            model: torch module to train; must expose a ``name`` attribute.
            train_loader: training ``DataLoader``; must expose ``name``.
            test_loader: evaluation ``DataLoader``.
            epochs: number of training epochs.
            batch_size: batch size (used only in the log-file name).
            run_id: identifier distinguishing repeated runs of the same setup.
            logs_dir: root directory for CSV logs and checkpoints.
            device: torch device string the model/data are moved to.
            saturation_device: device for saturation bookkeeping; defaults to
                ``device`` when ``None``.
            optimizer: one of ``"adam"``, ``"SGD"``, ``"LRS"`` (SGD + StepLR
                schedule) or ``"radam"``; anything else raises ``ValueError``.
            plot: stored flag, forwarded to plotting behavior.
            compute_top_k: stored flag for top-k accuracy computation.
            data_prallel: wrap the model in ``nn.DataParallel`` over
                cuda:0/cuda:1.  NOTE(review): name is a typo for
                ``data_parallel`` but is part of the public keyword interface,
                so it is kept for backward compatibility.
            conv_method: pooling strategy passed to ``CheckLayerSat``.

        Raises:
            ValueError: if ``optimizer`` is not a recognized name.
        """
        self.saturation_device = device if saturation_device is None else saturation_device
        self.device = device
        self.model = model
        self.epochs = epochs
        self.plot = plot
        self.compute_top_k = compute_top_k

        # cudnn autotuner pays off when input shapes are constant per batch.
        if 'cuda' in device:
            cudnn.benchmark = True

        self.train_loader = train_loader
        self.test_loader = test_loader

        # Reconstruction loss: outputs are compared against the inputs.
        self.criterion = nn.MSELoss()
        print('Checking for optimizer for {}'.format(optimizer))
        if optimizer == "adam":
            print('Using adam')
            self.optimizer = optim.Adam(model.parameters())
        elif optimizer == "SGD":
            print('Using SGD')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.01,
                                       momentum=0.9)
        elif optimizer == "LRS":
            # SGD with a step learning-rate schedule (stepped in train()).
            print('Using LRS')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.01,
                                       momentum=0.9)
            self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, 5)
        elif optimizer == "radam":
            print('Using radam')
            self.optimizer = RAdam(model.parameters())
        else:
            raise ValueError('Unknown optimizer {}'.format(optimizer))
        self.opt_name = optimizer
        save_dir = os.path.join(logs_dir, model.name, train_loader.name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        self.savepath = os.path.join(
            save_dir, f'{model.name}_bs{batch_size}_e{epochs}_id{run_id}.csv')
        # Skip training entirely if an identical experiment already finished:
        # the existing CSV's row count tells us how many epochs were logged.
        self.experiment_done = False
        if os.path.exists(self.savepath):
            trained_epochs = len(pd.read_csv(self.savepath, sep=';'))

            if trained_epochs >= epochs:
                self.experiment_done = True
                print(
                    f'Experiment Logs for the exact same experiment {self.savepath} with identical run_id was detecting, training will be skipped, consider using another run_id'
                )
        self.parallel = data_prallel
        if data_prallel:
            # NOTE(review): device ids are hard-coded to cuda:0/cuda:1.
            self.model = nn.DataParallel(self.model, ['cuda:0', 'cuda:1'])
        writer = CSVandPlottingWriter(self.savepath.replace('.csv', ''),
                                      fontsize=16,
                                      primary_metric='test_loss')
        self.pooling_strat = conv_method
        self.stats = CheckLayerSat(self.savepath.replace('.csv', ''),
                                   writer,
                                   model,
                                   stats=['lsat'],
                                   sat_threshold=.99,
                                   verbose=False,
                                   conv_method=conv_method,
                                   log_interval=1,
                                   device=self.saturation_device,
                                   reset_covariance=True,
                                   max_samples=None,
                                   ignore_layer_names='classifier666')

    def train(self):
        """Run the full training loop.

        Checkpoints the model and evaluates on the test set after every
        epoch; steps the LR scheduler when the "LRS" optimizer is used.

        Returns:
            The CSV log path, or ``None`` when the experiment was already
            complete and training was skipped.
        """
        if self.experiment_done:
            return
        self.model.to(self.device)
        for epoch in range(self.epochs):
            print("{} Epoch {}, training loss: {}".format(
                now(), epoch, self.train_epoch()))
            torch.save({'model_state_dict': self.model.state_dict()},
                       self.savepath.replace('.csv', '.pt'))
            self.test()
            if self.opt_name == "LRS":
                print('LRS step')
                self.lr_scheduler.step()
            self.stats.add_saturations()
        self.stats.close()
        # BUGFIX: savepath already ends in '.csv'; appending another '.csv'
        # returned a path ('...csv.csv') that no written file matches.
        return self.savepath

    def train_epoch(self):
        """Train for one epoch and return the running loss per sample.

        The reconstruction target is the input itself; labels from the
        loader are ignored.
        """
        self.model.train()
        total = 0
        running_loss = 0
        old_time = time()
        for batch, data in enumerate(self.train_loader):
            # Lightweight progress report every 5 batches.
            if batch % 5 == 0 and batch != 0:
                print(batch, 'of', len(self.train_loader), 'processing time',
                      time() - old_time, 'loss:', running_loss / total)
                old_time = time()
            inputs, _ = data
            inputs = inputs.to(self.device)
            total += inputs.size(0)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)

            loss = self.criterion(outputs, inputs)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
        self.stats.add_scalar('training_loss', running_loss / total)
        return running_loss / total

    def test(self):
        """Evaluate reconstruction loss on the test set and checkpoint.

        Returns nothing; the per-sample test loss is logged to the stats
        writer and printed.
        """
        self.model.eval()
        total = 0
        test_loss = 0
        with torch.no_grad():
            for batch, data in enumerate(self.test_loader):
                inputs, _ = data
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, inputs)
                test_loss += loss.item()
                total += inputs.size(0)

        self.stats.add_scalar('test_loss', test_loss / total)
        print('{} Test Loss on {} images: {:.2f}'.format(
            now(), total, test_loss / total))
        torch.save({'model_state_dict': self.model.state_dict()},
                   self.savepath.replace('.csv', '.pt'))
Пример #15
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--meta_path",
        default=None,
        type=str,
        required=False,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        '--classifier',
        default='guoday',
        type=str,
        required=True,
        help='classifier type, guoday or MLP or GRU_MLP or ...')
    parser.add_argument('--optimizer',
                        default='RAdam',
                        type=str,
                        required=True,
                        help='optimizer we use, RAdam or ...')
    parser.add_argument("--do_label_smoothing",
                        default='yes',
                        type=str,
                        required=True,
                        help="Whether to do label smoothing. yes or no.")
    parser.add_argument('--draw_loss_steps',
                        default=1,
                        type=int,
                        required=True,
                        help='training steps to draw loss')
    parser.add_argument('--label_name',
                        default='label',
                        type=str,
                        required=True,
                        help='label name in original train set index')

    ## Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train",
                        default='yes',
                        type=str,
                        required=True,
                        help="Whether to run training. yes or no.")
    parser.add_argument("--do_test",
                        default='yes',
                        type=str,
                        required=True,
                        help="Whether to run training. yes or no.")
    parser.add_argument("--do_eval",
                        default='yes',
                        type=str,
                        required=True,
                        help="Whether to run eval on the dev set. yes or no.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--eval_steps", default=200, type=int, help="")
    parser.add_argument("--lstm_hidden_size", default=300, type=int, help="")
    parser.add_argument("--lstm_layers", default=2, type=int, help="")
    parser.add_argument("--dropout", default=0.5, type=float, help="")

    parser.add_argument("--train_steps", default=-1, type=int, help="")
    parser.add_argument("--report_steps", default=-1, type=int, help="")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--split_num", default=3, type=int, help="text split")
    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    args = parser.parse_args()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    try:
        os.makedirs(args.output_dir)
    except:
        pass

    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path,
                                              do_lower_case=args.do_lower_case)

    # tensorboard_log_dir = args.output_dir

    # loss_now = tf.placeholder(dtype=tf.float32, name='loss_now')
    # loss_mean = tf.placeholder(dtype=tf.float32, name='loss_mean')
    # loss_now_variable = loss_now
    # loss_mean_variable = loss_mean
    # train_loss = tf.summary.scalar('train_loss', loss_now_variable)
    # dev_loss_mean = tf.summary.scalar('dev_loss_mean', loss_mean_variable)
    # merged = tf.summary.merge([train_loss, dev_loss_mean])

    config = BertConfig.from_pretrained(args.model_name_or_path, num_labels=3)
    config.hidden_dropout_prob = args.dropout

    # Prepare model
    if args.do_train == 'yes':
        model = BertForSequenceClassification.from_pretrained(
            args.model_name_or_path, args, config=config)

        if args.fp16:
            model.half()
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    if args.do_train == 'yes':
        print(
            '________________________now training______________________________'
        )
        # Prepare data loader

        train_examples = read_examples(os.path.join(args.data_dir,
                                                    'train.csv'),
                                       is_training=True,
                                       label_name=args.label_name)
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer,
                                                      args.max_seq_length,
                                                      args.split_num, True)
        # print('train_feature_size=', train_features.__sizeof__())
        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'),
                                     dtype=torch.long)
        all_input_mask = torch.tensor(select_field(train_features,
                                                   'input_mask'),
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(train_features,
                                                    'segment_ids'),
                                       dtype=torch.long)
        all_label = torch.tensor([f.label for f in train_features],
                                 dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label)
        # print('train_data=',train_data[0])
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps)

        num_train_optimization_steps = args.train_steps

        # Prepare optimizer

        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
        # thus it produce None grad that break apex
        param_optimizer = [n for n in param_optimizer]

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        if args.optimizer == 'RAdam':
            optimizer = RAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate)
        else:
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=args.train_steps)

        global_step = 0

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        best_acc = 0
        model.train()
        tr_loss = 0
        loss_batch = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        bar = tqdm(range(num_train_optimization_steps),
                   total=num_train_optimization_steps)
        train_dataloader = cycle(train_dataloader)

        # with tf.Session() as sess:
        #     summary_writer = tf.summary.FileWriter(tensorboard_log_dir, sess.graph)
        #     sess.run(tf.global_variables_initializer())

        list_loss_mean = []
        bx = []
        eval_F1 = []
        ax = []

        for step in bar:
            batch = next(train_dataloader)
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            loss = model(input_ids=input_ids,
                         token_type_ids=segment_ids,
                         attention_mask=input_mask,
                         labels=label_ids)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()
            loss_batch += loss.item()
            train_loss = round(
                tr_loss * args.gradient_accumulation_steps / (nb_tr_steps + 1),
                4)

            bar.set_description("loss {}".format(train_loss))
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            if args.fp16:
                # optimizer.backward(loss)
                loss.backward()
            else:
                loss.backward()

            # draw loss every n docs
            if (step + 1) % int(args.draw_loss_steps /
                                (args.train_batch_size /
                                 args.gradient_accumulation_steps)) == 0:
                list_loss_mean.append(round(loss_batch, 4))
                bx.append(step + 1)
                plt.plot(bx,
                         list_loss_mean,
                         label='loss_mean',
                         linewidth=1,
                         color='b',
                         marker='o',
                         markerfacecolor='green',
                         markersize=2)
                plt.savefig(args.output_dir + '/labeled.jpg')
                loss_batch = 0

            # paras update every batch data.
            if (nb_tr_steps + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = args.learning_rate * warmup_linear.get_lr(
                        global_step, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                scheduler.step()
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            # report results every 200 real batch.
            if step % (args.eval_steps *
                       args.gradient_accumulation_steps) == 0 and step > 0:
                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0
                logger.info("***** Report result *****")
                logger.info("  %s = %s", 'global_step', str(global_step))
                logger.info("  %s = %s", 'train loss', str(train_loss))

            # do evaluation totally 10 times during training stage.
            if args.do_eval == 'yes' and (step + 1) % int(
                    num_train_optimization_steps / 10) == 0 and step > 450:
                for file in ['dev.csv']:
                    inference_labels = []
                    gold_labels = []
                    inference_logits = []
                    eval_examples = read_examples(os.path.join(
                        args.data_dir, file),
                                                  is_training=True,
                                                  label_name=args.label_name)
                    eval_features = convert_examples_to_features(
                        eval_examples, tokenizer, args.max_seq_length,
                        args.split_num, False)
                    all_input_ids = torch.tensor(select_field(
                        eval_features, 'input_ids'),
                                                 dtype=torch.long)
                    all_input_mask = torch.tensor(select_field(
                        eval_features, 'input_mask'),
                                                  dtype=torch.long)
                    all_segment_ids = torch.tensor(select_field(
                        eval_features, 'segment_ids'),
                                                   dtype=torch.long)
                    all_label = torch.tensor([f.label for f in eval_features],
                                             dtype=torch.long)

                    eval_data = TensorDataset(all_input_ids, all_input_mask,
                                              all_segment_ids, all_label)

                    logger.info("***** Running evaluation *****")
                    logger.info("  Num examples = %d", len(eval_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    # Run prediction for full data
                    eval_sampler = SequentialSampler(eval_data)
                    eval_dataloader = DataLoader(
                        eval_data,
                        sampler=eval_sampler,
                        batch_size=args.eval_batch_size)

                    model.eval()
                    eval_loss, eval_accuracy = 0, 0
                    nb_eval_steps, nb_eval_examples = 0, 0
                    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)

                        with torch.no_grad():
                            tmp_eval_loss = model(input_ids=input_ids,
                                                  token_type_ids=segment_ids,
                                                  attention_mask=input_mask,
                                                  labels=label_ids)
                            logits = model(input_ids=input_ids,
                                           token_type_ids=segment_ids,
                                           attention_mask=input_mask)

                        logits = logits.detach().cpu().numpy()
                        label_ids = label_ids.to('cpu').numpy()
                        inference_labels.append(np.argmax(logits, axis=1))
                        gold_labels.append(label_ids)
                        inference_logits.append(logits)
                        eval_loss += tmp_eval_loss.mean().item()
                        nb_eval_examples += input_ids.size(0)
                        nb_eval_steps += 1

                    gold_labels = np.concatenate(gold_labels, 0)
                    inference_labels = np.concatenate(inference_labels, 0)
                    inference_logits = np.concatenate(inference_logits, 0)
                    model.train()
                    ###############################################
                    num_gold_0 = np.sum(gold_labels == 0)
                    num_gold_1 = np.sum(gold_labels == 1)
                    num_gold_2 = np.sum(gold_labels == 2)

                    right_0 = 0
                    right_1 = 0
                    right_2 = 0
                    error_0 = 0
                    error_1 = 0
                    error_2 = 0

                    for gold_label, inference_label in zip(
                            gold_labels, inference_labels):
                        if gold_label == inference_label:
                            if gold_label == 0:
                                right_0 += 1
                            elif gold_label == 1:
                                right_1 += 1
                            else:
                                right_2 += 1
                        elif inference_label == 0:
                            error_0 += 1
                        elif inference_label == 1:
                            error_1 += 1
                        else:
                            error_2 += 1

                    recall_0 = right_0 / (num_gold_0 + 1e-5)
                    recall_1 = right_1 / (num_gold_1 + 1e-5)
                    recall_2 = right_2 / (num_gold_2 + 1e-5)
                    precision_0 = right_0 / (error_0 + right_0 + 1e-5)
                    precision_1 = right_1 / (error_1 + right_1 + 1e-5)
                    precision_2 = right_2 / (error_2 + right_2 + 1e-5)
                    f10 = 2 * precision_0 * recall_0 / (precision_0 +
                                                        recall_0 + 1e-5)
                    f11 = 2 * precision_1 * recall_1 / (precision_1 +
                                                        recall_1 + 1e-5)
                    f12 = 2 * precision_2 * recall_2 / (precision_2 +
                                                        recall_2 + 1e-5)

                    output_dev_result_file = os.path.join(
                        args.output_dir, "dev_results.txt")
                    with open(output_dev_result_file, 'a',
                              encoding='utf-8') as f:
                        f.write('precision:' + str(precision_0) + ' ' +
                                str(precision_1) + ' ' + str(precision_2) +
                                '\n')
                        f.write('recall:' + str(recall_0) + ' ' +
                                str(recall_1) + ' ' + str(recall_2) + '\n')
                        f.write('f1:' + str(f10) + ' ' + str(f11) + ' ' +
                                str(f12) + '\n' + '\n')

                    eval_loss = eval_loss / nb_eval_steps
                    eval_accuracy = accuracy(inference_logits, gold_labels)
                    # draw loss.
                    eval_F1.append(round(eval_accuracy, 4))
                    ax.append(step)
                    plt.plot(ax,
                             eval_F1,
                             label='eval_F1',
                             linewidth=1,
                             color='r',
                             marker='o',
                             markerfacecolor='blue',
                             markersize=2)
                    for a, b in zip(ax, eval_F1):
                        plt.text(a, b, b, ha='center', va='bottom', fontsize=8)
                    plt.savefig(args.output_dir + '/labeled.jpg')

                    result = {
                        'eval_loss': eval_loss,
                        'eval_F1': eval_accuracy,
                        'global_step': global_step,
                        'loss': train_loss
                    }

                    output_eval_file = os.path.join(args.output_dir,
                                                    "eval_results.txt")
                    with open(output_eval_file, "a") as writer:
                        for key in sorted(result.keys()):
                            logger.info("  %s = %s", key, str(result[key]))
                            writer.write("%s = %s\n" % (key, str(result[key])))
                        writer.write('*' * 80)
                        writer.write('\n')
                    if eval_accuracy > best_acc and 'dev' in file:
                        print("=" * 80)
                        print("more accurate model arises, now best F1 = ",
                              eval_accuracy)
                        print("Saving Model......")
                        best_acc = eval_accuracy
                        # Save a trained model, only save the model it-self
                        model_to_save = model.module if hasattr(
                            model, 'module') else model
                        output_model_file = os.path.join(
                            args.output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        print("=" * 80)
                    '''
                    if (step+1) / int(num_train_optimization_steps/10) > 9.5:
                        print("=" * 80)
                        print("End of training. Saving Model......")
                        # Save a trained model, only save the model it-self
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = os.path.join(args.output_dir, "pytorch_model_final_step.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                        print("=" * 80)
                    '''

    if args.do_test == 'yes':
        start_time = time.time()
        print(
            '___________________now testing for best eval f1 model_________________________'
        )
        try:
            del model
        except:
            pass
        gc.collect()
        args.do_train = 'no'
        model = BertForSequenceClassification.from_pretrained(os.path.join(
            args.output_dir, "pytorch_model.bin"),
                                                              args,
                                                              config=config)
        model.half()
        for layer in model.modules():
            if isinstance(layer, torch.nn.modules.batchnorm._BatchNorm):
                layer.float()
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from "
                    "https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        for file, flag in [('test.csv', 'test')]:
            inference_labels = []
            gold_labels = []
            eval_examples = read_examples(os.path.join(args.data_dir, file),
                                          is_training=False,
                                          label_name=args.label_name)
            eval_features = convert_examples_to_features(
                eval_examples, tokenizer, args.max_seq_length, args.split_num,
                False)
            all_input_ids = torch.tensor(select_field(eval_features,
                                                      'input_ids'),
                                         dtype=torch.long)
            all_input_mask = torch.tensor(select_field(eval_features,
                                                       'input_mask'),
                                          dtype=torch.long)
            all_segment_ids = torch.tensor(select_field(
                eval_features, 'segment_ids'),
                                           dtype=torch.long)
            all_label = torch.tensor([f.label for f in eval_features],
                                     dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask,
                                      all_segment_ids, all_label)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data,
                                         sampler=eval_sampler,
                                         batch_size=args.eval_batch_size)

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(
                        input_ids=input_ids,
                        token_type_ids=segment_ids,
                        attention_mask=input_mask).detach().cpu().numpy()
                    # print('test_logits=', logits)
                label_ids = label_ids.to('cpu').numpy()
                inference_labels.append(logits)
                gold_labels.append(label_ids)
            gold_labels = np.concatenate(gold_labels, 0)
            logits = np.concatenate(inference_labels, 0)
            if flag == 'dev':
                print(flag, accuracy(logits, gold_labels))
            elif flag == 'test':
                df = pd.read_csv(os.path.join(args.data_dir, file))
                df['label_0'] = logits[:, 0]
                df['label_1'] = logits[:, 1]
                df['label_2'] = logits[:, 2]
                df[['id', 'label_0', 'label_1',
                    'label_2']].to_csv(os.path.join(args.output_dir,
                                                    "sub.csv"),
                                       index=False)
                # df[['id', 'label_0', 'label_1']].to_csv(os.path.join(args.output_dir, "sub.csv"), index=False)
            else:
                raise ValueError('flag not in [dev, test]')
        print('inference time usd = {}s'.format(time.time() - start_time))
        '''
Пример #16
0
def train_ccblock(model_options):
    """Train an RCCA attention block plus a quantization module on FCV videos.

    For every batch the two sub-networks are optimized alternately: the
    quantization module is updated first on its distortion losses, then the
    attention block is updated on a neighborhood-similarity loss computed
    against a precomputed pairwise similarity matrix.  Parameters are
    checkpointed every ``model_options.save_freq`` epochs and once more after
    the final epoch.

    Fixes vs. original: the training-loss record file and the similarity
    pickle are now managed with context managers (the record file was never
    closed).
    """
    # Resolve the list of training shards; trainset_path is a format string
    # when more than one shard is configured.
    if model_options.trainset_num > 1:
        train_file_paths = [
            model_options.trainset_path.format(i)
            for i in range(1, model_options.trainset_num + 1)
        ]
    else:
        train_file_paths = [model_options.trainset_path]

    # Load datasets.
    print(train_file_paths)
    label_paths = "/home/langruimin/BLSTM_pytorch/data/fcv/fcv_train_labels.mat"
    videoset = VideoDataset(train_file_paths, label_paths)
    print(len(videoset))

    # Create the attention model and the quantization head.
    model = RCCAModule(1, 2)
    model_quan = Quantization(16, 256, 1024)

    params_path = os.path.join(model_options.model_save_path,
                               model_options.params_filename)
    params_path_Q = os.path.join(model_options.model_save_path,
                                 model_options.Qparams_filename)
    if model_options.reload_params:
        print('Loading model params...')
        model.load_state_dict(torch.load(params_path))
        print('Done.')

    model = model.cuda()
    model_quan = model_quan.cuda()

    # Separate optimizers: the quantization head uses a 10x larger LR.
    optimizer = RAdam(model.parameters(),
                      lr=1e-4,
                      betas=(0.9, 0.999),
                      weight_decay=1e-4)
    optimizer2 = RAdam(model_quan.parameters(),
                       lr=1e-3,
                       betas=(0.9, 0.999),
                       weight_decay=1e-4)

    lr_C = ''
    lr_Q = ''

    # Load the precomputed pairwise similarity matrix (0/1 entries, stored
    # as a pickled numpy array).
    print("+++++++++loading similarity+++++++++")
    with open(
            "/home/langruimin/BLSTM_pytorch/data/fcv/SimilarityInfo/Sim_K1_10_K2_5_fcv.pkl",
            "rb") as f:
        similarity = pkl.load(f)
    similarity = torch.ByteTensor(similarity.astype(np.uint8))
    print("++++++++++similarity loaded+++++++")

    batch_idx = 1
    error_ = 0.       # running sum of quantization error
    loss_ = 0.        # running sum of neighbor loss
    num = 0           # batches counted for error_
    neighbor_num = 0  # batches counted for loss_
    neighbor = True
    neighbor_freq = 2
    total_batchs = len(videoset) // model_options.train_batch_size
    print("##########start train############")
    # NOTE(review): batch_size is hard-coded to 8 while total_batchs above is
    # derived from model_options.train_batch_size -- confirm which one is
    # intended before changing either.
    trainloader = torch.utils.data.DataLoader(videoset,
                                              batch_size=8,
                                              shuffle=True,
                                              num_workers=4,
                                              pin_memory=True)
    model.train()
    model_quan.train()

    neighbor_loss = 0.0
    with open(
            os.path.join(model_options.records_save_path,
                         model_options.train_loss_filename),
            'w') as train_loss_rec:
        for l in range(80):

            if neighbor:
                # Training pass over all batches.
                for i, (data, index, _, _) in enumerate(trainloader):
                    data = data.to(model_options.default_dtype)
                    data = data.unsqueeze(1)  # add a channel dimension
                    data = data.cuda()

                    output_ccblock_mean = torch.tanh(model(data))

                    # Quantization block: update its parameters first on the
                    # combined distortion losses.
                    Qhard, Qsoft, SoftDistortion, HardDistortion, JointCenter, error, _ = model_quan(
                        output_ccblock_mean)
                    Q_loss = 0.1 * SoftDistortion + HardDistortion + 0.1 * JointCenter

                    optimizer2.zero_grad()
                    # retain_graph: the same activations are reused by the
                    # neighbor-loss backward pass below.
                    Q_loss.backward(retain_graph=True)
                    optimizer2.step()

                    # Neighbor loss: inner products of batch embeddings should
                    # match the 0/1 similarity of the corresponding videos.
                    similarity_select = torch.index_select(similarity, 0, index)
                    similarity_select = torch.index_select(
                        similarity_select, 1, index).float().cuda()
                    neighbor_loss = torch.sum(
                        (torch.mm(output_ccblock_mean,
                                  output_ccblock_mean.transpose(0, 1)) /
                         output_ccblock_mean.shape[-1] - similarity_select).pow(2))

                    loss_ += neighbor_loss.item()
                    neighbor_num += 1

                    optimizer.zero_grad()
                    neighbor_loss.backward()
                    optimizer.step()

                    error_ += error.item()
                    num += 1
                    if batch_idx % model_options.disp_freq == 0:
                        info = "epoch{0} Batch {1} loss:{2:.3f}  distortion:{3:.3f} " \
                            .format(l, batch_idx, loss_ / neighbor_num, error_ / num)
                        print(info)
                        train_loss_rec.write(info + '\n')

                    batch_idx += 1
                # Reset running statistics for the next epoch.
                batch_idx = 0
                error_ = 0.
                loss_ = 0.
                num = 0
                neighbor_num = 0

            # Periodic checkpoint of both sub-networks.
            if (l + 1) % model_options.save_freq == 0:
                print('epoch: ', l, 'New best model. Saving model ...')
                torch.save(model.state_dict(), params_path)
                torch.save(model_quan.state_dict(), params_path_Q)

                for param_group in optimizer.param_groups:
                    lr_C = param_group['lr']
                for param_group in optimizer2.param_groups:
                    lr_Q = param_group['lr']
                record_inf = "saved model at epoch {0} lr_C:{1} lr_Q:{2}".format(
                    l, lr_C, lr_Q)
                train_loss_rec.write(record_inf + '\n')
            print("##########epoch done##########")

    # Final checkpoint after the last epoch.
    print('train done. Saving model ...')
    torch.save(model.state_dict(), params_path)
    torch.save(model_quan.state_dict(), params_path_Q)
    print("##########train done##########")
Пример #17
0
def main():
    """Entry point: train a burst-denoising UNet from command-line arguments.

    Parses CLI options, builds the training/validation data generators,
    optionally resumes from the newest checkpoint in the results directory,
    and runs the training loop with periodic checkpointing, validation and
    TensorBoard logging.

    Fixes vs. original: boolean CLI flags no longer use ``type=bool`` (which
    treats any non-empty string, including "False", as True); the model is
    put back into training mode after the in-epoch validation pass; log
    files are opened via context managers.
    """

    def _str2bool(v):
        # Backward-compatible boolean parser: accepts the same
        # "--flag True/False" spellings the original intended, but actually
        # parses them instead of bool('False') == True.
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', metavar='bs', type=int, default=2)  # batch size
    parser.add_argument('--path', type=str, default='../../data')
    parser.add_argument('--results', type=str, default='../../results/model')
    parser.add_argument('--nw', type=int, default=0)  # DataLoader workers
    parser.add_argument('--max_images', type=int, default=None)
    parser.add_argument('--val_size', type=int, default=None)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--lr_decay', type=float, default=0.99997)
    parser.add_argument('--kernel_lvl', type=float, default=1)
    parser.add_argument('--noise_lvl', type=float, default=1)
    parser.add_argument('--motion_blur', type=_str2bool, default=False)
    parser.add_argument('--homo_align', type=_str2bool, default=False)
    parser.add_argument('--resume', type=_str2bool, default=False)

    args = parser.parse_args()

    print()
    print(args)
    print()

    # Results directory; record the hyper-parameters for fresh runs only.
    os.makedirs(args.results, exist_ok=True)

    PATH = args.results
    if not args.resume:
        with open(PATH + "/param.txt", "a+") as f:
            f.write(str(args))

    writer = SummaryWriter(PATH + '/runs')

    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else "cpu")

    # Shared DataLoader parameters
    params = {'batch_size': args.bs, 'shuffle': True, 'num_workers': args.nw}

    # Generators
    print('Initializing training set')
    training_set = Dataset(args.path + '/train/', args.max_images,
                           args.kernel_lvl, args.noise_lvl, args.motion_blur,
                           args.homo_align)
    training_generator = data.DataLoader(training_set, **params)

    print('Initializing validation set')
    validation_set = Dataset(args.path + '/test/', args.val_size,
                             args.kernel_lvl, args.noise_lvl, args.motion_blur,
                             args.homo_align)
    validation_generator = data.DataLoader(validation_set, **params)

    # Model (optionally resumed from the newest checkpoint).
    model = UNet(in_channel=3, out_channel=3)
    if args.resume:
        models_path = get_newest_model(PATH)
        print('loading model from ', models_path)
        model.load_state_dict(torch.load(models_path))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    model.to(device)

    # Loss + optimizer + per-step LR decay.
    criterion = BurstLoss()
    optimizer = RAdam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=8 // args.bs, gamma=args.lr_decay)

    if args.resume:
        # Resume the global iteration counter from the training log.
        n_iter = np.loadtxt(PATH + '/train.txt', delimiter=',')[:, 0][-1]
    else:
        n_iter = 0

    # Loop over epochs
    for epoch in range(args.epochs):
        train_loss = 0.0

        model.train()
        for i, (X_batch, y_labels) in enumerate(training_generator):
            # Alter the burst length for each mini batch.
            burst_length = np.random.randint(2, 9)
            X_batch = X_batch[:, :burst_length, :, :, :]

            # Transfer to GPU
            X_batch, y_labels = X_batch.to(device).type(
                torch.float), y_labels.to(device).type(torch.float)

            # forward + backward + optimize
            optimizer.zero_grad()
            pred = model(X_batch)
            loss = criterion(pred, y_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.detach().cpu().numpy()
            writer.add_scalar('training_loss', loss.item(), n_iter)

            if i % 100 == 0 and i > 0:
                loss_printable = str(np.round(train_loss, 2))
                with open(PATH + "/train.txt", "a+") as f:
                    f.write(str(n_iter) + "," + loss_printable + "\n")
                print("training loss ", loss_printable)
                train_loss = 0.0

            if i % 1000 == 0:
                # Checkpoint; unwrap DataParallel so the saved state dict
                # loads without the 'module.' prefix.
                if torch.cuda.device_count() > 1:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()
                torch.save(
                    state_dict,
                    os.path.join(PATH, 'model_' + str(int(n_iter)) + '.pt'))

                # Validation pass plus a few sample images for TensorBoard.
                val_loss = 0.0
                with torch.set_grad_enabled(False):
                    model.eval()
                    for v, (X_batch,
                            y_labels) in enumerate(validation_generator):
                        burst_length = np.random.randint(2, 9)
                        X_batch = X_batch[:, :burst_length, :, :, :]

                        X_batch, y_labels = X_batch.to(device).type(
                            torch.float), y_labels.to(device).type(torch.float)

                        pred = model(X_batch)
                        loss = criterion(pred, y_labels)

                        val_loss += loss.detach().cpu().numpy()

                        if v < 5:
                            im = make_im(pred, X_batch, y_labels)
                            writer.add_image('image_' + str(v), im, n_iter)

                    writer.add_scalar('validation_loss', val_loss, n_iter)

                    loss_printable = str(np.round(val_loss, 2))
                    print('validation loss ', loss_printable)

                    with open(PATH + "/eval.txt", "a+") as f:
                        f.write(str(n_iter) + "," + loss_printable + "\n")
                # BUG FIX: restore training mode after validation; the
                # original left the model in eval() for the rest of the epoch.
                model.train()

            n_iter += args.bs
Пример #18
0
class ECGTrainer(object):
    """End-to-end trainer for a DenseNet-based multi-label ECG classifier.

    Construction wires together the model, a combined loss, an RAdam
    optimizer and a ReduceLROnPlateau scheduler; :meth:`run` performs the
    train/validation loop, checkpoints the model with the best validation
    F1, and finally searches and saves per-class decision thresholds.
    """
    def __init__(self, pre_trained=None, block_config='small', num_threads=2):
        """Configure hyper-parameters and build model/criterion/optimizer/scheduler.

        pre_trained: used here only as a flag -- when not None, the model is
            built with a 55-class head that is then replaced by a 34-class
            one, and a smaller learning rate is used.
            NOTE(review): the actual loading of pre-trained weights is not
            visible in this class; confirm against callers.
        block_config: 'small' selects a shallower DenseNet configuration.
        num_threads: intra-op thread count for torch.
        """
        torch.set_num_threads(num_threads)
        self.n_epochs = 60      # fixed number of training epochs
        self.batch_size = 128   # training batch size (validation uses 64)
        self.scheduler = None   # set by __build_scheduler()
        self.pre_trained = pre_trained
        self.num_threads = num_threads
        self.cuda = torch.cuda.is_available()

        if block_config == 'small':
            self.block_config = (3, 6, 12, 8)
        else:
            self.block_config = (6, 12, 24, 16)

        self.__build_model()
        self.__build_criterion()
        self.__build_optimizer()
        self.__build_scheduler()
        return

    def __build_model(self):
        """Create the DenseNet; for the pre-trained path, swap in a 34-class head."""
        if self.pre_trained is not None:
            # Checkpoint architecture has 55 output classes; replace the
            # classifier with a fresh 34-class linear layer.
            self.model = DenseNet(num_classes=55,
                                  block_config=self.block_config)
            in_features = self.model.classifier.in_features
            self.model.classifier = torch.nn.Linear(in_features, 34)
        else:
            self.model = DenseNet(num_classes=34,
                                  block_config=self.block_config)
        if self.cuda:
            self.model.cuda()
        return

    def __build_criterion(self):
        """Combined loss: multi-label soft margin + F1 + focal (focal weighted 3x)."""
        self.criterion = ComboLoss(losses=['mlsml', 'f1', 'focal'],
                                   weights=[1, 1, 3])
        return

    def __build_optimizer(self):
        """RAdam optimizer; lower LR (1e-3) when fine-tuning a pre-trained model."""
        lr = 1e-3 if self.pre_trained is not None else 1e-2
        opt_params = {
            'lr': lr,
            'weight_decay': 0.0,
            'params': self.model.parameters()
        }
        self.optimizer = RAdam(**opt_params)
        return

    def __build_scheduler(self):
        """Reduce LR by 10x when validation F1 plateaus for 5 epochs (floor 1e-5)."""
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                              'max',
                                                              factor=0.1,
                                                              patience=5,
                                                              verbose=True,
                                                              min_lr=1e-5)
        return

    def run(self, trainset, validset, model_dir):
        """Train and validate; save the best model and searched thresholds.

        trainset/validset: passed straight to ECGLoader.
        model_dir: output directory for 'model.pth' and 'threshold.npy'.
        """
        print('=' * 100 + '\n' + 'TRAINING MODEL\n' + '-' * 100 + '\n')
        model_path = os.path.join(model_dir, 'model.pth')
        thresh_path = os.path.join(model_dir, 'threshold.npy')

        # Training loader shuffles; validation uses a fixed batch size of 64.
        dataloader = {
            'train':
            ECGLoader(trainset, self.batch_size, True,
                      self.num_threads).build(),
            'valid':
            ECGLoader(validset, 64, False, self.num_threads).build()
        }

        best_metric, best_preds = None, None
        for epoch in range(self.n_epochs):
            e_message = '[EPOCH {:0=3d}/{:0=3d}]'.format(
                epoch + 1, self.n_epochs)

            for phase in ['train', 'valid']:
                ep_message = e_message + '[' + phase.upper() + ']'
                if phase == 'train':
                    self.model.train()
                else:
                    self.model.eval()

                losses, preds, labels = [], [], []
                batch_num = len(dataloader[phase])
                for ith_batch, data in enumerate(dataloader[phase]):
                    ecg, label = [d.cuda()
                                  for d in data] if self.cuda else data

                    pred = self.model(ecg)
                    loss = self.criterion(pred, label)
                    if phase == 'train':
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                    # Sigmoid logits -> per-class probabilities on CPU.
                    pred = torch.sigmoid(pred)
                    pred = pred.data.cpu().numpy()
                    label = label.data.cpu().numpy()

                    # Binarize at 0.5 for the running micro-F1 display.
                    bin_pred = np.copy(pred)
                    bin_pred[bin_pred > 0.5] = 1
                    bin_pred[bin_pred <= 0.5] = 0
                    f1 = f1_score(label.flatten(), bin_pred.flatten())

                    losses.append(loss.item())
                    preds.append(pred)
                    labels.append(label)

                    # Per-step progress line, overwritten in place (\r).
                    sr_message = '[STEP {:0=3d}/{:0=3d}]-[Loss: {:.6f} F1: {:.6f}]'
                    sr_message = ep_message + sr_message
                    print(sr_message.format(ith_batch + 1, batch_num, loss,
                                            f1),
                          end='\r')

                # Epoch-level metrics over all batches of this phase.
                preds = np.concatenate(preds, axis=0)
                labels = np.concatenate(labels, axis=0)
                bin_preds = np.copy(preds)
                bin_preds[bin_preds > 0.5] = 1
                bin_preds[bin_preds <= 0.5] = 0

                avg_loss = np.mean(losses)
                avg_f1 = f1_score(labels.flatten(), bin_preds.flatten())
                er_message = '-----[Loss: {:.6f} F1: {:.6f}]'
                er_message = '\n\033[94m' + ep_message + er_message + '\033[0m'
                print(er_message.format(avg_loss, avg_f1))

                if phase == 'valid':
                    # Scheduler steps on validation F1 ('max' mode).
                    if self.scheduler is not None:
                        self.scheduler.step(avg_f1)
                    # Keep the checkpoint and raw predictions of the best epoch.
                    if best_metric is None or best_metric < avg_f1:
                        best_metric = avg_f1
                        best_preds = [labels, preds]
                        best_loss_metrics = [epoch + 1, avg_loss, avg_f1]
                        torch.save(self.model.state_dict(), model_path)
                        print('[Best validation metric, model: {}]'.format(
                            model_path))
                    print()

        # Search per-class thresholds on the best epoch's predictions.
        best_f1, best_th = best_f1_score(*best_preds)
        np.save(thresh_path, np.array(best_th))
        print('[Searched Best F1: {:.6f}]\n'.format(best_f1))
        res_message = '[VALIDATION PERFORMANCE: BEST F1]' + '\n' \
            + '[EPOCH:{} LOSS:{:.6f} F1:{:.6f} BEST F1:{:.6f}]\n'.format(
                best_loss_metrics[0], best_loss_metrics[1],
                best_loss_metrics[2], best_f1) \
            + '[BEST THRESHOLD:\n{}]\n'.format(best_th) \
            + '=' * 100 + '\n'
        print(res_message)
        return
Пример #19
0
class Trainer():
    """Training-loop driver for the MAC model on CLEVR.

    Owns the data loaders, the live model plus its EMA shadow copy, the
    optimizer, and TensorBoard logging.  Call ``train()`` to run the full
    schedule.
    """

    def __init__(self, log_dir, cfg):
        """Set up output directories, data loaders, model/EMA pair and optimizer.

        Args:
            log_dir: root output directory (checkpoints under Model/,
                TensorBoard events under Log/).
            cfg: configuration namespace; the TRAIN / DATASET / GPU_ID
                fields referenced below must be present.
        """
        self.path = log_dir
        self.cfg = cfg

        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(self.path, 'Model')
            self.log_dir = os.path.join(self.path, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.log_dir)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            # Tee stdout into a logfile for the whole run.
            sys.stdout = Logger(logfile=os.path.join(self.path, "logfile.log"))

        self.data_dir = cfg.DATASET.DATA_DIR
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.lr = cfg.TRAIN.LEARNING_RATE

        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Load datasets.
        # NOTE(review): the config attribute is spelled "congent"; it looks
        # like a typo of "cogent" (CLEVR-CoGenT split suffix) -- confirm
        # against the config schema before renaming it.
        cogent = cfg.DATASET.congent
        # BUG FIX: in the original the body of this guard was not indented,
        # which is an IndentationError.  The training split is only loaded
        # when actually training.  "args" is presumably a module-level
        # namespace from argparse -- TODO confirm it is in scope here.
        if not args.eval and not args.test:
            self.dataset = ClevrDataset(data_dir=self.data_dir, split="train" + cogent)
            self.dataloader = DataLoader(dataset=self.dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True,
                                         num_workers=cfg.WORKERS, drop_last=False, collate_fn=collate_fn)

        self.dataset_val = ClevrDataset(data_dir=self.data_dir, split="val" + cogent)
        self.dataloader_val = DataLoader(dataset=self.dataset_val, batch_size=256, drop_last=False,
                                         shuffle=False, num_workers=cfg.WORKERS, collate_fn=collate_fn)

        # Load the model and an EMA shadow copy; alpha=0 copies the live
        # weights into the EMA model so both start identical.
        self.vocab = load_vocab(cfg)
        self.model, self.model_ema = mac.load_MAC(cfg, self.vocab)
        self.weight_moving_average(alpha=0)
        if cfg.TRAIN.RADAM:
            self.optimizer = RAdam(self.model.parameters(), lr=self.lr)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.previous_best_acc = 0.0
        self.previous_best_epoch = 0

        self.total_epoch_loss = 0
        self.prior_epoch_loss = 10

        self.print_info()
        self.loss_fn = torch.nn.CrossEntropyLoss().cuda()

    def print_info(self):
        """Print the active config, dataset size and model architecture."""
        print('Using config:')
        pprint.pprint(self.cfg)
        print("\n")

        # The training dataset only exists outside eval/test mode.
        if hasattr(self, 'dataset'):
            pprint.pprint("Size of dataset: {}".format(len(self.dataset)))
            print("\n")

        print("Using MAC-Model:")
        pprint.pprint(self.model)
        print("\n")

    def weight_moving_average(self, alpha=0.999):
        """EMA update: ema_param <- alpha * ema_param + (1 - alpha) * param."""
        for ema_param, param in zip(self.model_ema.parameters(), self.model.parameters()):
            ema_param.data *= alpha
            ema_param.data += (1.0 - alpha) * param.data

    def set_mode(self, mode="train"):
        """Put both the live and the EMA model into train or eval mode."""
        if mode == "train":
            self.model.train()
            self.model_ema.train()
        else:
            self.model.eval()
            self.model_ema.eval()

    def reduce_lr(self):
        """Halve the learning rate when the epoch loss has plateaued.

        The improvement threshold tightens as the loss gets smaller, and
        each tier has a learning-rate floor below which no further
        reduction is applied.  Resets the epoch-loss accumulator.
        """
        epoch_loss = self.total_epoch_loss  # / float(len(self.dataset) // self.batch_size)
        loss_diff = self.prior_epoch_loss - epoch_loss
        if ((loss_diff < 0.015 and self.prior_epoch_loss < 0.5 and self.lr > 0.00002) or
                (loss_diff < 0.008 and self.prior_epoch_loss < 0.15 and self.lr > 0.00001) or
                (loss_diff < 0.003 and self.prior_epoch_loss < 0.10 and self.lr > 0.000005)):
            self.lr *= 0.5
            print("Reduced learning rate to {}".format(self.lr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        self.prior_epoch_loss = epoch_loss
        self.total_epoch_loss = 0

    def save_models(self, iteration):
        """Checkpoint the live model (with optimizer state) and the EMA model."""
        save_model(self.model, self.optimizer, iteration, self.model_dir, model_name="model")
        save_model(self.model_ema, None, iteration, self.model_dir, model_name="model_ema")

    def train_epoch(self, epoch):
        """Run one training epoch.

        Args:
            epoch: zero-based epoch index (used only for the progress bar).

        Returns:
            dict with keys "avg_loss" and "train_accuracy" (running
            averages over the whole epoch).
        """
        cfg = self.cfg
        total_loss = 0
        total_correct = 0
        total_samples = 0

        self.labeled_data = iter(self.dataloader)
        self.set_mode("train")

        dataset = tqdm(self.labeled_data, total=len(self.dataloader))

        for data in dataset:
            # (1) Prepare training data.
            image, question, question_len, answer = data['image'], data['question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if cfg.CUDA:
                image = image.cuda()
                question = question.cuda()
                answer = answer.cuda().squeeze()
            else:
                answer = answer.squeeze()

            # (2) Optimize the live model, then fold its weights into the EMA copy.
            self.optimizer.zero_grad()

            scores = self.model(image, question, question_len)
            loss = self.loss_fn(scores, answer)
            loss.backward()

            if self.cfg.TRAIN.CLIP_GRADS:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.TRAIN.CLIP)

            self.optimizer.step()
            self.weight_moving_average()

            # (3) Accumulate running metrics for the progress bar.
            correct = scores.detach().argmax(1) == answer
            total_correct += correct.sum().cpu().item()
            total_loss += loss.item() * answer.size(0)
            total_samples += answer.size(0)

            avg_loss = total_loss / total_samples
            train_accuracy = total_correct / total_samples

            dataset.set_description(
                'Epoch: {}; Avg Loss: {:.5f}; Avg Train Acc: {:.5f}'.format(epoch + 1, avg_loss, train_accuracy)
            )

        # NOTE(review): raises NameError if the dataloader is empty --
        # assumed non-empty in practice.
        self.total_epoch_loss = avg_loss
        print(self.total_epoch_loss)

        return {
            "avg_loss": avg_loss,
            "train_accuracy": train_accuracy,
        }

    def train(self):
        """Run the full schedule: train, reduce LR, validate/log, early-stop."""
        cfg = self.cfg
        print("Start Training")
        for epoch in range(self.max_epochs):
            metrics = self.train_epoch(epoch)
            self.reduce_lr()
            self.log_results(epoch, metrics)
            # NOTE(review): config key is spelled "EALRY_STOPPING"; kept
            # as-is to match the existing config schema.
            if cfg.TRAIN.EALRY_STOPPING:
                # Stop when no validation improvement for PATIENCE epochs.
                if epoch - cfg.TRAIN.PATIENCE == self.previous_best_epoch:
                    break

        self.save_models(self.max_epochs)
        self.writer.close()
        print("Finished Training")
        # BUG FIX: the original format string was printed without arguments.
        print("Highest validation accuracy: {} at epoch {}".format(
            self.previous_best_acc, self.previous_best_epoch))

    def log_results(self, epoch, dict, max_eval_samples=None):
        """Log train/val metrics to TensorBoard and checkpoint periodically.

        Args:
            epoch: zero-based epoch index (logged and tracked as epoch + 1).
            dict: metrics dict from ``train_epoch``.  The name shadows the
                builtin; kept for signature compatibility.
            max_eval_samples: forwarded to ``calc_accuracy`` (unused there).
        """
        epoch += 1
        self.writer.add_scalar("avg_loss", dict["avg_loss"], epoch)
        self.writer.add_scalar("train_accuracy", dict["train_accuracy"], epoch)

        val_accuracy, val_accuracy_ema = self.calc_accuracy("validation", max_samples=max_eval_samples)
        self.writer.add_scalar("val_accuracy_ema", val_accuracy_ema, epoch)
        self.writer.add_scalar("val_accuracy", val_accuracy, epoch)

        print("Epoch: {}\tVal Acc: {},\tVal Acc EMA: {},\tAvg Loss: {},\tLR: {}".
              format(epoch, val_accuracy, val_accuracy_ema, dict["avg_loss"], self.lr))

        if val_accuracy > self.previous_best_acc:
            self.previous_best_acc = val_accuracy
            self.previous_best_epoch = epoch

        if epoch % self.snapshot_interval == 0:
            self.save_models(epoch)

    def calc_accuracy(self, mode="train", max_samples=None):
        """Compute accuracy of the live and EMA models over a split.

        Args:
            mode: "train" or "validation" -- selects the loader.  Both
                models are always put into eval mode regardless.
            max_samples: unused; kept for signature compatibility.

        Returns:
            (accuracy, accuracy_ema) tuple of floats.
        """
        self.set_mode("validation")

        if mode == "train":
            loader = self.dataloader
        elif mode == "validation":
            loader = self.dataloader_val

        total_correct = 0
        total_correct_ema = 0
        total_samples = 0

        for data in tqdm(loader, total=len(loader)):
            image, question, question_len, answer = data['image'], data['question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if self.cfg.CUDA:
                image = image.cuda()
                question = question.cuda()
                answer = answer.cuda().squeeze()

            with torch.no_grad():
                scores = self.model(image, question, question_len)
                scores_ema = self.model_ema(image, question, question_len)

            correct_ema = scores_ema.detach().argmax(1) == answer
            total_correct_ema += correct_ema.sum().cpu().item()

            correct = scores.detach().argmax(1) == answer
            total_correct += correct.sum().cpu().item()
            total_samples += answer.size(0)

        accuracy_ema = total_correct_ema / total_samples
        accuracy = total_correct / total_samples

        return accuracy, accuracy_ema
Пример #20
0
def train(args):
    """Train PretrainedCNN on the Bengali grapheme task for one CV fold.

    The model emits 186 logits split into three heads: grapheme root
    (0:168), vowel diacritic (168:179) and consonant diacritic (179:).
    Uses apex AMP (O1) mixed precision and CutMix on half the batches;
    checkpoints are kept when the weighted macro-recall beats the best
    score so far.

    Args:
        args: namespace with img_size, fold, batch_size, gpu_n,
            pretrain_path, start_lr, epoch, save_queue, exp_name.
    """
    # Augmentations: Cutout is applied at train time only.
    train_transform = Compose([
        Resize(args.img_size, args.img_size),
        Cutout(num_holes=8, max_h_size=20, max_w_size=20, fill_value=0, always_apply=False, p=0.5),
        Normalize(
            mean=[0.0692],
            std=[0.205],
        ),
        ToTensorV2()
    ])
    val_transform = Compose([
        Resize(args.img_size, args.img_size),
        Normalize(
            mean=[0.0692],
            std=[0.205],
        ),
        ToTensorV2()
    ])

    # Load fold assignments; an explicit fold must be selected.
    df_train = pd.read_csv("../input/train_folds.csv")

    if args.fold == -1:
        sys.exit()

    # Renamed from "train"/"val" to avoid shadowing this function's name.
    train_df = df_train[df_train['kfold'] != args.fold].reset_index(drop=True)
    val_df = df_train[df_train['kfold'] == args.fold].reset_index(drop=True)

    train_data = ImageDataset('../input/images', train_transform, train_df)
    train_loader = utils.DataLoader(train_data, shuffle=True, num_workers=5, batch_size=args.batch_size, pin_memory=True)

    val_data = ImageDataset('../input/images', val_transform, val_df)
    val_loader = utils.DataLoader(val_data, shuffle=False, num_workers=5, batch_size=args.batch_size, pin_memory=True)

    # Create the model, optionally warm-starting from a pretrained checkpoint.
    device = torch.device(f"cuda:{args.gpu_n}")
    model = PretrainedCNN()

    if args.pretrain_path != "":
        model.load_state_dict(torch.load(args.pretrain_path, map_location=f"cuda:{args.gpu_n}"))
        print("weights loaded")
    model.to(device)

    optimizer = RAdam(model.parameters(), lr=args.start_lr)

    opt_level = 'O1'
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    loss_fn = nn.CrossEntropyLoss()
    scheduler = ReduceLROnPlateau(optimizer, mode='max', verbose=True, patience=8, factor=0.6)

    best_models = deque(maxlen=5)
    best_score = 0.99302  # checkpoint only when beating this baseline

    for e in range(args.epoch):

        # Training phase.
        train_loss = []
        model.train()

        for image, target in tqdm(train_loader, ncols=70):
            optimizer.zero_grad()
            xs = image.to(device)
            ys = target.to(device)

            # CutMix on ~half the batches.
            if np.random.rand() < 0.5:
                images, targets = cutmix(xs, ys[:, 0], ys[:, 1], ys[:, 2], 1.0)
                # BUG FIX: the original ran the model on the un-mixed batch
                # ("pred = model(xs)") while training against the mixed
                # targets, which the author had flagged as a bug.
                pred = model(images)
                output1 = pred[:, :168]     # grapheme root logits
                output2 = pred[:, 168:179]  # vowel diacritic logits
                output3 = pred[:, 179:]     # consonant diacritic logits
                loss = cutmix_criterion(output1, output2, output3, targets)
            else:
                pred = model(xs)
                grapheme = pred[:, :168]
                vowel = pred[:, 168:179]
                cons = pred[:, 179:]

                loss = loss_fn(grapheme, ys[:, 0]) + loss_fn(vowel, ys[:, 1]) + loss_fn(cons, ys[:, 2])

            # AMP loss scaling for mixed-precision backward.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            optimizer.step()
            train_loss.append(loss.item())

        # Validation phase.
        val_loss = []
        val_true = []
        val_pred = []
        model.eval()
        with torch.no_grad():
            for image, target in val_loader:
                xs = image.to(device)
                ys = target.to(device)

                pred = model(xs)
                grapheme = pred[:, :168]
                vowel = pred[:, 168:179]
                cons = pred[:, 179:]

                loss = loss_fn(grapheme, ys[:, 0]) + loss_fn(vowel, ys[:, 1]) + loss_fn(cons, ys[:, 2])
                val_loss.append(loss.item())

                grapheme = grapheme.cpu().argmax(dim=1).data.numpy()
                vowel = vowel.cpu().argmax(dim=1).data.numpy()
                cons = cons.cpu().argmax(dim=1).data.numpy()
                val_true.append(target.numpy())
                val_pred.append(np.stack([grapheme, vowel, cons], axis=1))

        val_true = np.concatenate(val_true)
        val_pred = np.concatenate(val_pred)

        val_loss = np.mean(val_loss)
        train_loss = np.mean(train_loss)

        # Competition metric: macro recall per head, weighted 2:1:1.
        scores = []
        for i in [0, 1, 2]:
            scores.append(sklearn.metrics.recall_score(val_true[:, i], val_pred[:, i], average='macro'))
        final_score = np.average(scores, weights=[2, 1, 1])

        print(f'Epoch: {e:03d}; train_loss: {train_loss:.05f}; val_loss: {val_loss:.05f}; ', end='')
        print(f'score: {final_score:.5f} ', end='')

        # Checkpoint: keep the 5 best states when save_queue is set
        # (2nd stage, 224x224), otherwise overwrite a single best model.
        if final_score > best_score:
            best_score = final_score
            state_dict = copy.deepcopy(model.state_dict())
            if args.save_queue == 1:
                best_models.append(state_dict)
                for i, m in enumerate(best_models):
                    path = f"models/{args.exp_name}"
                    os.makedirs(path, exist_ok=True)
                    torch.save(m, join(path, f"{i}.pt"))
            else:
                path = f"models/{args.exp_name}"
                os.makedirs(path, exist_ok=True)
                torch.save(state_dict, join(path, "model.pt"))
            print('+')
        else:
            print()

        scheduler.step(final_score)
Пример #21
0
class TasNET_trainer(object):
    """Training/validation driver for a two-speaker TasNET separator.

    The objective is negative SI-SNR with permutation-invariant training
    (PIT) over the two possible output/target pairings.  Progress is
    written to TensorBoard, a plain-text log file, and the module logger.
    """

    def __init__(self,
                 TasNET,
                 batch_size,
                 checkpoint="checkpoint",
                 log_folder="./log",
                 rnn_arch="LSTM",
                 optimizer='radam',
                 rerun_mode=False,
                 lr=1e-5,
                 momentum=0.9,
                 weight_decay=0,
                 num_epoches=20,
                 clip_norm=False,
                 sr=8000,
                 cudnnBenchmark=True):
        """Configure optimizer, scheduler, logging and checkpointing.

        Args:
            TasNET: model to train; moved to the module-level ``device``.
            batch_size: logged only.
            checkpoint: directory for model snapshots (created if missing).
            log_folder: directory for TensorBoard events and the text log.
            rnn_arch: logged only.
            optimizer: 'radam' selects RAdam, anything else Adam.
            rerun_mode: logged only.
            lr: learning rate; string values are coerced to float.
            momentum: unused -- kept for signature compatibility.
            weight_decay: L2 penalty passed to the optimizer.
            num_epoches: number of training epochs.
            clip_norm: if truthy, gradients are clipped to this 2-norm.
            sr: sample rate (stored only).
            cudnnBenchmark: toggles torch.backends.cudnn.benchmark.
        """
        logger.info('---Experiment Variables---')
        logger.info('RNN Architecture: ' + rnn_arch)
        logger.info('Batch Size      : ' + str(batch_size))
        logger.info('Optimizer       : ' + optimizer)
        logger.info('--------------------------\n')

        logger.info('Rerun mode: ' + str(rerun_mode))

        self.TasNET = TasNET

        self.log_folder = log_folder
        self.writer = SummaryWriter(log_folder)
        self.all_log = 'all_log.log'  # filename of the persistent text log
        self.log('Progress Log save path: ' + log_folder)

        self.log("TasNET:\n{}".format(self.TasNET))
        # Config files sometimes deliver lr as a string; coerce once here.
        if type(lr) is str:
            lr = float(lr)
            logger.info("Transfrom lr from str to float => {}".format(lr))

        self.log('Batch size used: ' + str(batch_size))

        if optimizer == 'radam':
            self.optimizer = RAdam(self.TasNET.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            self.optimizer = torch.optim.Adam(
                self.TasNET.parameters(),
                lr=lr,
                weight_decay=weight_decay)

        # Halve the LR after 3 epochs without dev-loss improvement.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                            'min', factor=0.5, patience=3, verbose=True)
        self.TasNET.to(device)

        self.checkpoint = checkpoint
        self.log('Model save path: ' + checkpoint)

        self.num_epoches = num_epoches
        self.clip_norm = clip_norm
        self.sr = sr

        if self.clip_norm:
            self.log("Clip gradient by 2-norm {}".format(clip_norm))

        if not os.path.exists(self.checkpoint):
            os.makedirs(checkpoint)

        torch.backends.cudnn.benchmark = cudnnBenchmark
        self.log('cudnn benchmark status: ' + str(torch.backends.cudnn.benchmark) + '\n')

    def SISNR(self, output, target):
        """Per-sample negative SI-SNR, 10*log10(|e_noise|^2 / |s_target|^2).

        Both signals are mean-centered, the target-aligned component
        s_target is the projection of output onto target, and e_noise is
        the residual.  Lower is better (it is used directly as a loss).

        Args:
            output: (B, T) separated waveforms.
            target: reference waveforms, reshaped to (B, -1).

        Returns:
            (B, 1) tensor of negative SI-SNR values in dB.
        """
        batchsize = np.shape(output)[0]
        target = target.view(batchsize, -1)
        output = output - torch.mean(output, 1, keepdim=True)
        target = target - torch.mean(target, 1, keepdim=True)

        s_shat = torch.sum(output * target, 1, keepdim=True)
        s_2 = torch.sum(target ** 2, 1, keepdim=True)
        s_target = (s_shat / s_2) * target  # projection onto the target

        e_noise = output - s_target

        return 10 * torch.log10(torch.sum(e_noise ** 2, 1, keepdim=True)
                    / torch.sum(s_target ** 2, 1, keepdim=True))  # (B, 1)

    def loss(self, output1, output2, target1, target2):
        """PIT loss: the better of the two output/target pairings, averaged."""
        pairing1 = self.SISNR(output1, target1) + self.SISNR(output2, target2)
        pairing2 = self.SISNR(output1, target2) + self.SISNR(output2, target1)
        best_pairing = torch.min(pairing1, pairing2)  # (B, 1)
        return torch.mean(best_pairing)  # scalar

    def train(self, dataloader, epoch):
        """Run one training epoch.

        Returns:
            (mean loss over the epoch, number of batches).
        """
        self.TasNET.train()
        self.log("Training...")
        tot_loss = 0
        tot_batch = len(dataloader)
        batch_indx = (epoch - 1) * tot_batch

        currProcess = 0
        # BUG FIX: for dataloaders with fewer than 20 batches the original
        # tot_batch // 20 was 0, causing ZeroDivisionError in the modulo.
        fivePercentProgress = max(1, tot_batch // 20)

        for mix_speech, speech1, speech2 in dataloader:
            self.optimizer.zero_grad()

            if torch.cuda.is_available():
                mix_speech = mix_speech.cuda()
                speech1 = speech1.cuda()
                speech2 = speech2.cuda()

            mix_speech = Variable(mix_speech)
            speech1 = Variable(speech1)
            speech2 = Variable(speech2)

            output1, output2 = self.TasNET(mix_speech)
            cur_loss = self.loss(output1, output2, speech1, speech2)
            tot_loss += cur_loss.item()

            # Write per-batch summary at a global (cross-epoch) step index.
            batch_indx += 1
            self.writer.add_scalar('train_loss', cur_loss, batch_indx)
            cur_loss.backward()
            if self.clip_norm:
                nn.utils.clip_grad_norm_(self.TasNET.parameters(),
                                         self.clip_norm)
            self.optimizer.step()
            currProcess += 1
            if currProcess % fivePercentProgress == 0:
                self.log('batch {}: {:.2f}% progress ({}/{})| LR: {}'.format(batch_indx, currProcess*100/tot_batch, currProcess, tot_batch, str(self.get_curr_lr())))

        return tot_loss / tot_batch, tot_batch

    def validate(self, dataloader, epoch):
        """Evaluate one epoch on the dev set (no gradients).

        Returns:
            (mean loss over the epoch, number of batches).
        """
        self.TasNET.eval()
        self.log("Evaluating...")
        tot_loss = 0
        tot_batch = len(dataloader)
        batch_indx = (epoch - 1) * tot_batch

        currProcess = 0
        # Same guard as in train(): avoid modulo-by-zero on tiny loaders.
        fivePercentProgress = max(1, tot_batch // 20)

        with torch.no_grad():
            for mix_speech, speech1, speech2 in dataloader:
                if torch.cuda.is_available():
                    mix_speech = mix_speech.cuda()
                    speech1 = speech1.cuda()
                    speech2 = speech2.cuda()

                mix_speech = Variable(mix_speech)
                speech1 = Variable(speech1)
                speech2 = Variable(speech2)

                output1, output2 = self.TasNET(mix_speech)
                cur_loss = self.loss(output1, output2, speech1, speech2)
                tot_loss += cur_loss.item()

                batch_indx += 1
                currProcess += 1
                if currProcess % fivePercentProgress == 0:
                    self.log('batch {}: {:.2f}% progress ({}/{})| LR: {}'.format(batch_indx, currProcess*100/tot_batch, currProcess, tot_batch, str(self.get_curr_lr())))
                self.writer.add_scalar('dev_loss', cur_loss, batch_indx)
        return tot_loss / tot_batch, tot_batch

    def run(self, train_set, dev_set):
        """Full training run from scratch: initial dev eval, then all epochs.

        Saves a checkpoint (with train/dev loss in the filename) after
        every epoch and steps the LR scheduler on the dev loss.
        """
        init_loss, _ = self.validate(dev_set, 1)
        self.log("Start training for {} epoches".format(self.num_epoches))
        self.log("Epoch {:2d}: dev loss ={:.4e}".format(0, init_loss))
        torch.save(self.TasNET.state_dict(), os.path.join(self.checkpoint, 'TasNET_0.pkl'))
        for epoch in range(1, self.num_epoches + 1):
            train_start = time.time()
            train_loss, train_num_batch = self.train(train_set, epoch)
            valid_start = time.time()
            valid_loss, valid_num_batch = self.validate(dev_set, epoch)
            valid_end = time.time()
            self.scheduler.step(valid_loss)
            self.log(
                "Epoch {:2d}: train loss = {:.4e}({:.2f}s/{:d}) |"
                " dev loss= {:.4e}({:.2f}s/{:d})".format(
                    epoch, train_loss, valid_start - train_start,
                    train_num_batch, valid_loss, valid_end - valid_start,
                    valid_num_batch))
            save_path = os.path.join(
                self.checkpoint, "TasNET_{:d}_trainloss_{:.4e}_valloss_{:.4e}.pkl".format(
                    epoch, train_loss, valid_loss))
            torch.save(self.TasNET.state_dict(), save_path)
        self.log("Training for {} epoches done!".format(self.num_epoches))

    def rerun(self, train_set, dev_set, model_path, epoch_done):
        """Resume training from a saved checkpoint.

        Args:
            train_set / dev_set: dataloaders as in ``run``.
            model_path: state-dict file to load before resuming.
            epoch_done: last completed epoch; training restarts at +1.
        """
        self.TasNET.load_state_dict(torch.load(model_path))
        for epoch in range(epoch_done + 1, self.num_epoches + 1):
            train_start = time.time()
            train_loss, train_num_batch = self.train(train_set, epoch)
            valid_start = time.time()
            valid_loss, valid_num_batch = self.validate(dev_set, epoch)
            valid_end = time.time()
            self.scheduler.step(valid_loss)
            self.log(
                "Epoch {:2d}: train loss = {:.4e}({:.2f}s/{:d}) |"
                " dev loss= {:.4e}({:.2f}s/{:d})".format(
                    epoch, train_loss, valid_start - train_start,
                    train_num_batch, valid_loss, valid_end - valid_start,
                    valid_num_batch))
            save_path = os.path.join(
                self.checkpoint, "TasNET_{:d}_trainloss_{:.4e}_valloss_{:.4e}.pkl".format(
                    epoch, train_loss, valid_loss))
            torch.save(self.TasNET.state_dict(), save_path)
        self.log("Training for {} epoches done!".format(self.num_epoches))

    def get_curr_lr(self):
        """Return the learning rate of the first optimizer param group."""
        for param_group in self.optimizer.param_groups:
            return float(param_group['lr'])

    def log(self, log_data):
        """Log to the module logger and append to the persistent text log.

        File persistence is best-effort: an I/O failure is reported but
        never interrupts training.
        """
        logger.info(log_data)
        try:
            # 'with' guarantees the handle is closed even if write fails.
            with open(self.log_folder + '/' + self.all_log, 'a+') as f:
                f.write(log_data + '\n')
        except OSError:
            logger.info('failed to save last log')
Пример #22
0
        DC.train()
        
        STZ.zero_grad()
        ZTE.zero_grad()
        DC.zero_grad()

        x_train,y_train = _x.float().cuda(),_y.float().cuda()
        
        z = STZ(x_train)
        recon = ZTE(z)
        recon_loss = criterion(recon,y_train)
#         with amp.scale_loss(recon_loss, [STZ_optimizer_enc,ZTE_optimizer]) as scaled_loss:
#             scaled_loss.backward()
        recon_loss.backward()
    
        STZ_optimizer_enc.step()
        ZTE_optimizer.step()
        
        #Discriminator
        STZ.eval()
        
        STZ.zero_grad()
        ZTE.zero_grad()
        DC.zero_grad()

        z_real = EGG_prior(y_train,extract = True)
        z_fake = STZ(x_train)
        DC_real = DC(z_real)
        DC_fake = DC(z_fake)
        
        DC_loss = -torch.mean(torch.log(DC_real + EPS) + torch.log(1-DC_fake + EPS))
Пример #23
0
            optimizer.zero_grad()

            # Reset LSTM hidden state
            model.lstm.reset_hidden_state()

            # Get sequence predictions
            predictions = model(image_sequences)

            # Compute metrics
            loss = cls_criterion(predictions, labels)
            acc = 100 * (np.argmax(predictions.detach().cpu().numpy(), axis=1)
                         == labels.cpu().numpy()).mean()

            loss.backward()
            optimizer.step()

            # Keep track of epoch metrics
            epoch_metrics["loss"].append(loss.item())
            epoch_metrics["acc"].append(acc)

            # Determine approximate time left
            batches_done = epoch * len(train_dataloader) + batch_i
            batches_left = opt.num_epochs * len(
                train_dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
Пример #24
0
def train(opt):
    """Train a scene-text-recognition model configured by ``opt``.

    Builds the balanced train dataset and validation loader, instantiates the
    model / label converter / loss / optimizer from the option namespace,
    optionally restores a checkpoint or pretrained weights, then runs the
    iteration loop with periodic validation, best-model tracking and
    checkpointing.

    Args:
        opt: argparse-style namespace holding every experiment option
            (data paths, architecture choices, optimizer settings, logging
            intervals, ...).

    Side effects:
        Writes logs and model snapshots under
        ``./saved_models/{opt.experiment_name}/`` and checkpoints under
        ``./checkpoints/{opt.experiment_name}/``.
    """
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    # The converter maps text labels <-> integer sequences; its flavour must
    # match the prediction head (CTC / Transformer / Attention).
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif opt.Prediction == 'None':
        converter = TransformerConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization: kaiming for weights, zeros for biases;
    # BatchNorm params have no fan dimensions, so kaiming raises and we
    # fall back to filling weights with 1.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    # model = torch.nn.DataParallel(model).to(device)
    model = model.to(device)
    model.train()
    if opt.load_from_checkpoint:
        model.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'checkpoint.pth')))
        print(f'loaded checkpoint from {opt.load_from_checkpoint}...')
    elif opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.SequenceModeling == 'Transformer':
            # Only the FeatureExtraction sub-module is transferable from a
            # non-Transformer checkpoint; strip the DataParallel prefix.
            fe_state = OrderedDict()
            state_dict = torch.load(opt.saved_model)
            for k, v in state_dict.items():
                if k.startswith('module.FeatureExtraction'):
                    new_k = re.sub('module.FeatureExtraction.', '', k)
                    fe_state[new_k] = state_dict[k]
            model.FeatureExtraction.load_state_dict(fe_state)
        else:
            if opt.FT:
                model.load_state_dict(torch.load(opt.saved_model), strict=False)
            else:
                model.load_state_dict(torch.load(opt.saved_model))
    if opt.freeze_fe:
        model.freeze(['FeatureExtraction'])
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif opt.Prediction == 'None':
        criterion = LabelSmoothingLoss(classes=converter.n_classes, padding_idx=converter.pad_idx, smoothing=0.1)
        # criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.pad_idx)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        assert opt.adam in ['Adam', 'AdamW', 'RAdam'], 'adam optimizer must be in Adam, AdamW or RAdam'
        if opt.adam == 'Adam':
            optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        elif opt.adam == "AdamW":
            optimizer = optim.AdamW(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        else:
            optimizer = RAdam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    if opt.load_from_checkpoint and opt.load_optimizer_state:
        optimizer.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'optimizer.pth')))
        print(f'loaded optimizer state from {os.path.join(opt.load_from_checkpoint, "optimizer.pth")}')

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        # Filenames like 'iter_30000.pth' encode the iteration to resume from.
        # Was a bare `except:` which also swallowed KeyboardInterrupt; only a
        # non-numeric suffix (ValueError) is an expected, ignorable failure.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    if opt.load_from_checkpoint:
        # iter.json stores the next iteration index as a plain integer.
        with open(os.path.join(opt.load_from_checkpoint, 'iter.json'), mode='r', encoding='utf8') as f:
            start_iter = json.load(f)
            print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # i = start_iter

    bar = tqdm(range(start_iter, opt.num_iter))
    # while(True):
    for i in bar:
        bar.set_description(f'Iter {i}: train_loss = {loss_avg.val():.5f}')
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            torch.backends.cudnn.enabled = True

            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
            # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)

        elif opt.Prediction == 'None':
            # Transformer head: teacher-forced decoding with padding mask.
            tgt_input = text['tgt_input']
            tgt_output = text['tgt_output']
            tgt_padding_mask = text['tgt_padding_mask']
            preds = model(image, tgt_input.transpose(0, 1), tgt_key_padding_mask=tgt_padding_mask,)
            cost = criterion(preds.view(-1, preds.shape[-1]), tgt_output.contiguous().view(-1))
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')

                # checkpoint
                os.makedirs(f'./checkpoints/{opt.experiment_name}/', exist_ok=True)

                torch.save(model.state_dict(), f'./checkpoints/{opt.experiment_name}/checkpoint.pth')
                torch.save(optimizer.state_dict(), f'./checkpoints/{opt.experiment_name}/optimizer.pth')
                with open(f'./checkpoints/{opt.experiment_name}/iter.json', mode='w', encoding='utf8') as f:
                    json.dump(i + 1, f)

                with open(f'./checkpoints/{opt.experiment_name}/checkpoint.log', mode='a', encoding='utf8') as f:
                    f.write(f'Saved checkpoint with iter={i}\n')
                    f.write(f'\tCheckpoint at: ./checkpoints/{opt.experiment_name}/checkpoint.pth')
                    f.write(f'\tOptimizer at: ./checkpoints/{opt.experiment_name}/optimizer.pth')

                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter. (integer literal instead of the float
        # 1e+5 — same truth value, avoids float modulo)
        if (i + 1) % 100000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # if i == opt.num_iter:
        #     print('end the training')
        #     sys.exit()
        # i += 1
        # if i == 1: break
    print('end training')
Пример #25
0
def go(arg):
    """Train a character-level GTransformer language model.

    Samples random subsequences from the training split each step, optimizes
    NLL of the next character, periodically evaluates bits-per-byte on the
    validation split, keeps the best checkpoint, and finalizes on test data.

    Args:
        arg: argparse-style namespace with model/optimization/logging options
            (embedding_size, num_heads, depth, context, lr, num_batches, ...).
    """

    if arg.seed < 0:
        seed = random.randint(0, 1000000)
        print('random seed: ', seed)
        # Bug fix: the drawn seed was printed but never applied, so runs with
        # arg.seed < 0 could not be reproduced from the logged seed.
        torch.manual_seed(seed)
    else:
        torch.manual_seed(arg.seed)

    tbw = SummaryWriter(log_dir=arg.tb_dir)  # Tensorboard logging

    # load the data
    arg.path = here('data') if arg.path is None else arg.path
    data_train, data_val, data_test = read_dataset(arg.path, arg.dataset)

    # create the model
    model = GTransformer(emb=arg.embedding_size,
                         heads=arg.num_heads,
                         depth=arg.depth,
                         seq_length=arg.context,
                         num_tokens=NUM_TOKENS,
                         wide=arg.wide)

    if torch.cuda.is_available():
        model.cuda()

    print("Model parameters = %d" % sum(p.numel() for p in model.parameters()))

    if not arg.radam:
        opt = torch.optim.Adam(lr=arg.lr, params=model.parameters())
        # linear learning rate warmup
        sch = torch.optim.lr_scheduler.LambdaLR(
            opt, lambda i: min(i / (arg.lr_warmup / arg.batch_size), 1.0))
    else:
        opt = RAdam(model.parameters(), lr=arg.lr)

    if USE_APEX:
        model, opt = amp.initialize(model, opt, opt_level="O1", verbosity=0)

    best_bpb = np.inf
    best_step = 0

    # training loop
    # - note: we don't loop over the data, instead we sample a batch of random subsequences each time.
    for i in tqdm.trange(arg.num_batches):

        opt.zero_grad()

        # sample a batch of random subsequences
        starts = torch.randint(size=(arg.batch_size, ),
                               low=0,
                               high=data_train.size(0) - arg.context - 1)
        seqs_source = [
            data_train[start:start + arg.context] for start in starts
        ]
        seqs_target = [
            data_train[start + 1:start + arg.context + 1] for start in starts
        ]
        source = torch.cat([s[None, :] for s in seqs_source],
                           dim=0).to(torch.long)
        target = torch.cat([s[None, :] for s in seqs_target],
                           dim=0).to(torch.long)
        # - target is the same sequence as source, except one character ahead

        if torch.cuda.is_available():
            source, target = source.cuda(), target.cuda()
        source, target = Variable(source), Variable(target)

        output = model(source)

        loss = F.nll_loss(output.transpose(2, 1), target, reduction='mean')
        #tbw.add_scalar('transformer/train-loss', float(loss.item()) * LOG2E, i * arg.batch_size)

        if not USE_APEX:
            loss.backward()
        else:
            with amp.scale_loss(loss, opt) as scaled_loss:
                scaled_loss.backward()

        # clip gradients
        # - If the total gradient vector has a length > 1, we clip it back down to 1.
        if arg.gradient_clipping > 0.0:
            nn.utils.clip_grad_norm_(model.parameters(), arg.gradient_clipping)

        opt.step()

        if not arg.radam:
            sch.step()

        # - validate every {arg.test_every} steps. First we compute the
        #   compression on the validation (or a subset)
        #   then we generate some random text to monitor progress
        if i != 0 and (i % arg.test_every == 0 or i == arg.num_batches - 1):

            upto = arg.test_subset if arg.test_subset else data_val.size(0)
            data_sub = data_val[:upto]

            bits_per_byte = calculate_bpb(arg, model, data_sub)

            # print validation performance. 1 bit per byte is (currently) state of the art.
            print(f'epoch{i}: {bits_per_byte:.4} bits per byte')

            tag_scalar_dict = {
                'train-loss': float(loss.item()) * LOG2E,
                'eval-loss': bits_per_byte
            }
            tbw.add_scalars('transformer/loss', tag_scalar_dict,
                            i * arg.batch_size)

            if bits_per_byte < best_bpb:
                # track the best validation score and snapshot the weights
                best_bpb = bits_per_byte
                best_step = i
                torch.save(model.state_dict(),
                           os.path.join(arg.tb_dir, 'best_model.pt'))

            print(f'best step {best_step}: {best_bpb:.4} bits per byte')

            generate_sequence(arg, model, data_val)

    # load the best model, calculate bpb of the test data and generate some random text
    finalize(arg, model, data_test)
Пример #26
0
class Trainer():
    """Training harness for a MAC-style VQA model on CLEVR or GQA.

    Owns the datasets/loaders, the model plus an EMA copy of it, the
    optimizer, tensorboard + Comet logging, checkpointing, LR reduction and
    early stopping. Construction is driven entirely by the ``cfg`` object.
    """

    def __init__(self, log_dir, cfg):
        """Set up logging dirs, data loaders, model/EMA pair and optimizer.

        Args:
            log_dir: root directory for models, tensorboard logs and logfile.
            cfg: project config object (attribute-style access; also
                JSON-serializable, since it is dumped to cfg.json below).
        """

        self.path = log_dir
        self.cfg = cfg

        if cfg.TRAIN.FLAG:
            # Training mode: create output dirs and tee stdout to a logfile.
            self.model_dir = os.path.join(self.path, 'Model')
            self.log_dir = os.path.join(self.path, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.log_dir)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            self.logfile = os.path.join(self.path, "logfile.log")
            sys.stdout = Logger(logfile=self.logfile)

        self.data_dir = cfg.DATASET.DATA_DIR
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.lr = cfg.TRAIN.LEARNING_RATE

        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        sample = cfg.SAMPLE
        self.dataset = []
        self.dataloader = []
        self.use_feats = cfg.model.use_feats
        eval_split = cfg.EVAL if cfg.EVAL else 'val'
        train_split = cfg.DATASET.train_split
        if cfg.DATASET.DATASET == 'clevr':
            clevr_collate_fn = collate_fn
            cogent = cfg.DATASET.COGENT
            if cogent:
                print(f'Using CoGenT {cogent.upper()}')

            if cfg.TRAIN.FLAG:
                self.dataset = ClevrDataset(data_dir=self.data_dir,
                                            split=train_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=clevr_collate_fn)

            self.dataset_val = ClevrDataset(data_dir=self.data_dir,
                                            split=eval_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             drop_last=False,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             collate_fn=clevr_collate_fn)

        elif cfg.DATASET.DATASET == 'gqa':
            # GQA can use either spatial grid features or object features;
            # the collate function differs accordingly.
            if self.use_feats == 'spatial':
                gqa_collate_fn = collate_fn_gqa
            elif self.use_feats == 'objects':
                gqa_collate_fn = collate_fn_gqa_objs
            if cfg.TRAIN.FLAG:
                self.dataset = GQADataset(data_dir=self.data_dir,
                                          split=train_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=gqa_collate_fn)

            self.dataset_val = GQADataset(data_dir=self.data_dir,
                                          split=eval_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             drop_last=False,
                                             collate_fn=gqa_collate_fn)

        # load model
        self.vocab = load_vocab(cfg)
        self.model, self.model_ema = mac.load_MAC(cfg, self.vocab)

        # alpha=0 copies the live model's weights into the EMA model.
        self.weight_moving_average(alpha=0)
        if cfg.TRAIN.RADAM:
            self.optimizer = RAdam(self.model.parameters(), lr=self.lr)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.start_epoch = 0
        if cfg.resume_model:
            # Resume both the live model (+ optimizer) and the EMA model.
            location = 'cuda' if cfg.CUDA else 'cpu'
            state = torch.load(cfg.resume_model, map_location=location)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])
            self.start_epoch = state['iter'] + 1
            state = torch.load(cfg.resume_model_ema, map_location=location)
            self.model_ema.load_state_dict(state['model'])

        if cfg.start_epoch is not None:
            self.start_epoch = cfg.start_epoch

        # Best-so-far tracking for early stopping / reporting.
        self.previous_best_acc = 0.0
        self.previous_best_epoch = 0
        self.previous_best_loss = 100
        self.previous_best_loss_epoch = 0

        # State for the loss-plateau LR schedule in reduce_lr().
        self.total_epoch_loss = 0
        self.prior_epoch_loss = 10

        self.print_info()
        self.loss_fn = torch.nn.CrossEntropyLoss().cuda()

        # Comet experiment is created either way; disabled unless cfg.logcomet.
        self.comet_exp = Experiment(
            project_name=cfg.COMET_PROJECT_NAME,
            api_key=os.getenv('COMET_API_KEY'),
            workspace=os.getenv('COMET_WORKSPACE'),
            disabled=cfg.logcomet is False,
        )
        if cfg.logcomet:
            exp_name = cfg_to_exp_name(cfg)
            print(exp_name)
            self.comet_exp.set_name(exp_name)
            self.comet_exp.log_parameters(flatten_json_iterative_solution(cfg))
            self.comet_exp.log_asset(self.logfile)
            self.comet_exp.log_asset_data(json.dumps(cfg, indent=4),
                                          file_name='cfg.json')
            self.comet_exp.set_model_graph(str(self.model))
            if cfg.cfg_file:
                self.comet_exp.log_asset(cfg.cfg_file)

        with open(os.path.join(self.path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=4)

    def print_info(self):
        """Pretty-print the config, dataset sizes and model architecture."""
        print('Using config:')
        pprint.pprint(self.cfg)
        print("\n")

        pprint.pprint("Size of train dataset: {}".format(len(self.dataset)))
        # print("\n")
        pprint.pprint("Size of val dataset: {}".format(len(self.dataset_val)))
        print("\n")

        print("Using MAC-Model:")
        pprint.pprint(self.model)
        print("\n")

    def weight_moving_average(self, alpha=0.999):
        """Update EMA weights in place: ema = alpha*ema + (1-alpha)*live.

        alpha=0 makes the EMA model an exact copy of the live model.
        """
        for param1, param2 in zip(self.model_ema.parameters(),
                                  self.model.parameters()):
            param1.data *= alpha
            param1.data += (1.0 - alpha) * param2.data

    def set_mode(self, mode="train"):
        """Switch both the live and EMA models between train and eval mode."""
        if mode == "train":
            self.model.train()
            self.model_ema.train()
        else:
            self.model.eval()
            self.model_ema.eval()

    def reduce_lr(self):
        """Halve the LR when the epoch loss has plateaued.

        Uses three (loss-delta, loss-level, lr-floor) thresholds that get
        progressively tighter as the loss shrinks; also rolls the per-epoch
        loss accumulator over to prior_epoch_loss.
        """
        epoch_loss = self.total_epoch_loss  # / float(len(self.dataset) // self.batch_size)
        lossDiff = self.prior_epoch_loss - epoch_loss
        if ((lossDiff < 0.015 and self.prior_epoch_loss < 0.5 and self.lr > 0.00002) or \
            (lossDiff < 0.008 and self.prior_epoch_loss < 0.15 and self.lr > 0.00001) or \
            (lossDiff < 0.003 and self.prior_epoch_loss < 0.10 and self.lr > 0.000005)):
            self.lr *= 0.5
            print("Reduced learning rate to {}".format(self.lr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        self.prior_epoch_loss = epoch_loss
        self.total_epoch_loss = 0

    def save_models(self, iteration):
        """Snapshot both the live model (with optimizer) and the EMA model."""
        save_model(self.model,
                   self.optimizer,
                   iteration,
                   self.model_dir,
                   model_name="model")
        save_model(self.model_ema,
                   None,
                   iteration,
                   self.model_dir,
                   model_name="model_ema")

    def train_epoch(self, epoch):
        """Run one training epoch; returns a metrics dict (loss, accuracy)."""
        cfg = self.cfg
        total_loss = 0.
        total_correct = 0
        total_samples = 0

        self.labeled_data = iter(self.dataloader)
        self.set_mode("train")

        dataset = tqdm(self.labeled_data, total=len(self.dataloader), ncols=20)

        for data in dataset:
            ######################################################
            # (1) Prepare training data
            ######################################################
            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if cfg.CUDA:
                # 'objects' features arrive as a list of tensors, so each is
                # moved to GPU individually.
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()
            else:
                question = question
                image = image
                answer = answer.squeeze()

            ############################
            # (2) Train Model
            ############################
            self.optimizer.zero_grad()

            scores = self.model(image, question, question_len)
            loss = self.loss_fn(scores, answer)
            loss.backward()

            if self.cfg.TRAIN.CLIP_GRADS:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.cfg.TRAIN.CLIP)

            self.optimizer.step()
            # Update the EMA copy after every optimizer step.
            self.weight_moving_average()

            ############################
            # (3) Log Progress
            ############################
            correct = scores.detach().argmax(1) == answer
            total_correct += correct.sum().cpu().item()
            total_loss += loss.item() * answer.size(0)
            total_samples += answer.size(0)

            avg_loss = total_loss / total_samples
            train_accuracy = total_correct / total_samples
            # accuracy = correct.sum().cpu().numpy() / answer.shape[0]

            # if avg_loss == 0:
            #     avg_loss = loss.item()
            #     train_accuracy = accuracy
            # else:
            #     avg_loss = 0.99 * avg_loss + 0.01 * loss.item()
            #     train_accuracy = 0.99 * train_accuracy + 0.01 * accuracy
            # self.total_epoch_loss += loss.item() * answer.size(0)

            dataset.set_description(
                'Epoch: {}; Avg Loss: {:.5f}; Avg Train Acc: {:.5f}'.format(
                    epoch + 1, avg_loss, train_accuracy))

        # NOTE(review): assumes the dataloader is non-empty; avg_loss is only
        # bound inside the loop.
        self.total_epoch_loss = avg_loss

        dict = {
            "loss": avg_loss,
            "accuracy": train_accuracy,
            "avg_loss": avg_loss,  # For commet
            "avg_accuracy": train_accuracy,  # For commet
        }
        return dict

    def train(self):
        """Full training loop: epochs of train + validate, with early stop."""
        cfg = self.cfg
        print("Start Training")
        for epoch in range(self.start_epoch, self.max_epochs):

            with self.comet_exp.train():
                dict = self.train_epoch(epoch)
                self.reduce_lr()
                dict['epoch'] = epoch + 1
                dict['lr'] = self.lr
                self.comet_exp.log_metrics(
                    dict,
                    epoch=epoch + 1,
                )

            with self.comet_exp.validate():
                dict = self.log_results(epoch, dict)
                dict['epoch'] = epoch + 1
                dict['lr'] = self.lr
                self.comet_exp.log_metrics(
                    dict,
                    epoch=epoch + 1,
                )

            # Config key 'EALRY_STOPPING' is misspelled upstream; kept as-is
            # because it must match the config file.
            if cfg.TRAIN.EALRY_STOPPING:
                if epoch - cfg.TRAIN.PATIENCE == self.previous_best_epoch:
                    # if epoch - cfg.TRAIN.PATIENCE == self.previous_best_loss_epoch:
                    print('Early stop')
                    break

        self.comet_exp.log_asset(self.logfile)
        self.save_models(self.max_epochs)
        self.writer.close()
        print("Finished Training")
        print(
            f"Highest validation accuracy: {self.previous_best_acc} at epoch {self.previous_best_epoch}"
        )

    def log_results(self, epoch, dict, max_eval_samples=None):
        """Log train metrics, run validation, update best-so-far trackers.

        Note: the parameter named ``dict`` shadows the builtin (kept for
        interface compatibility). Returns the validation metrics dict.
        """
        epoch += 1
        self.writer.add_scalar("avg_loss", dict["loss"], epoch)
        self.writer.add_scalar("train_accuracy", dict["accuracy"], epoch)

        metrics = self.calc_accuracy("validation",
                                     max_samples=max_eval_samples)
        self.writer.add_scalar("val_accuracy_ema", metrics['acc_ema'], epoch)
        self.writer.add_scalar("val_accuracy", metrics['acc'], epoch)
        self.writer.add_scalar("val_loss_ema", metrics['loss_ema'], epoch)
        self.writer.add_scalar("val_loss", metrics['loss'], epoch)

        print(
            "Epoch: {epoch}\tVal Acc: {acc},\tVal Acc EMA: {acc_ema},\tAvg Loss: {loss},\tAvg Loss EMA: {loss_ema},\tLR: {lr}"
            .format(epoch=epoch, lr=self.lr, **metrics))

        if metrics['acc'] > self.previous_best_acc:
            self.previous_best_acc = metrics['acc']
            self.previous_best_epoch = epoch
        if metrics['loss'] < self.previous_best_loss:
            self.previous_best_loss = metrics['loss']
            self.previous_best_loss_epoch = epoch

        if epoch % self.snapshot_interval == 0:
            self.save_models(epoch)

        return metrics

    def calc_accuracy(self, mode="train", max_samples=None):
        """Evaluate both live and EMA models; returns acc/loss for each.

        ``max_samples`` is accepted but not used by the current
        implementation — the full loader is always consumed.
        """
        self.set_mode("validation")

        if mode == "train":
            loader = self.dataloader
        # elif (mode == "validation") or (mode == 'test'):
        #     loader = self.dataloader_val
        else:
            loader = self.dataloader_val

        total_correct = 0
        total_correct_ema = 0
        total_samples = 0
        total_loss = 0.
        total_loss_ema = 0.
        pbar = tqdm(loader, total=len(loader), desc=mode.upper(), ncols=20)
        for data in pbar:

            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if self.cfg.CUDA:
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()

            with torch.no_grad():
                scores = self.model(image, question, question_len)
                scores_ema = self.model_ema(image, question, question_len)

                loss = self.loss_fn(scores, answer)
                loss_ema = self.loss_fn(scores_ema, answer)

            correct = scores.detach().argmax(1) == answer
            correct_ema = scores_ema.detach().argmax(1) == answer

            total_correct += correct.sum().cpu().item()
            total_correct_ema += correct_ema.sum().cpu().item()

            total_loss += loss.item() * answer.size(0)
            total_loss_ema += loss_ema.item() * answer.size(0)

            total_samples += answer.size(0)

            avg_acc = total_correct / total_samples
            avg_acc_ema = total_correct_ema / total_samples
            avg_loss = total_loss / total_samples
            avg_loss_ema = total_loss_ema / total_samples

            pbar.set_postfix({
                'Acc': f'{avg_acc:.5f}',
                'Acc Ema': f'{avg_acc_ema:.5f}',
                'Loss': f'{avg_loss:.5f}',
                'Loss Ema': f'{avg_loss_ema:.5f}',
            })

        return dict(acc=avg_acc,
                    acc_ema=avg_acc_ema,
                    loss=avg_loss,
                    loss_ema=avg_loss_ema)
Пример #27
0
def train(args):
    """Train a segmentation model according to command-line ``args``.

    Builds model, optimizer and LR scheduler from ``args``, optionally runs
    an initial validation pass, then trains with gradient accumulation,
    validating and checkpointing every ``args.iter_val`` iterations.

    Relies on module-level helpers not visible in this block:
    ``create_model``, ``get_train_val_loaders``, ``validate``,
    ``save_model``, ``get_lrs``, ``criterion`` and ``_reduce_loss``.
    """
    # Number of batches whose gradients are accumulated before each
    # optimizer step (effective batch size = args.batch_size * ACCUM_STEPS).
    ACCUM_STEPS = 4

    model, model_file = create_model(args.encoder_type,
                                     work_dir=args.work_dir,
                                     ckp=args.ckp)
    model = model.cuda()

    loaders = get_train_val_loaders(batch_size=args.batch_size)

    if args.optim_name == 'RAdam':
        optimizer = RAdam(model.parameters(), lr=args.lr)
    elif args.optim_name == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim_name == 'SGD':
        optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=args.lr)
    else:
        # Fail fast: the original silently fell through and raised a
        # confusing NameError at the first use of ``optimizer``.
        raise ValueError('Unknown optimizer: {}'.format(args.optim_name))

    if torch.cuda.device_count() > 1:
        model = DataParallel(model)

    if args.lrs == 'plateau':
        # 'max' mode because the monitored metric is dice (higher = better).
        lr_scheduler = ReduceLROnPlateau(optimizer,
                                         mode='max',
                                         factor=args.factor,
                                         patience=args.patience,
                                         min_lr=args.min_lr)
    else:
        # Any other value of args.lrs falls back to cosine annealing.
        lr_scheduler = CosineAnnealingLR(optimizer,
                                         args.t_max,
                                         eta_min=args.min_lr)

    best_metrics = 0.
    best_key = 'dice'

    print(
        'epoch |    lr    |      %        |  loss  |  avg   |   loss |  dice  |  best  | time |  save |'
    )

    if not args.no_first_val:
        # Establish the baseline metric before any training.
        val_metrics = validate(args, model, loaders['valid'])
        print(
            'val   |          |               |        |        | {:.4f} | {:.4f} | {:.4f} |        |        |'
            .format(val_metrics['loss'], val_metrics['dice'],
                    val_metrics['dice']))

        best_metrics = val_metrics[best_key]

    if args.val:
        # Validation-only mode: nothing more to do.
        return

    model.train()

    train_iter = 0

    for epoch in range(args.num_epochs):
        train_loss = 0

        current_lr = get_lrs(optimizer)
        bg = time.time()
        pending_grads = False  # True while un-stepped gradients are held.
        for batch_idx, data in enumerate(loaders['train']):
            train_iter += 1
            img, targets = data[0].cuda(), data[1].cuda()

            outputs = model(img)
            # NOTE(review): ``criterion`` is assumed to be a module-level
            # loss function -- not visible in this block.
            loss = _reduce_loss(criterion(outputs, targets))
            loss.backward()
            pending_grads = True

            # Step once every ACCUM_STEPS batches.  The original tested
            # ``batch_idx % 4 == 0``, which stepped after the very first
            # batch alone and let a trailing partial group's gradients
            # leak (un-stepped and un-zeroed) into the next epoch.
            if (batch_idx + 1) % ACCUM_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                pending_grads = False

            train_loss += loss.item()
            print('\r {:4d} | {:.6f} | {:06d}/{} | {:.4f} | {:.4f} |'.format(
                epoch, float(current_lr[0]),
                args.batch_size * (batch_idx + 1), loaders['train'].num,
                loss.item(), train_loss / (batch_idx + 1)),
                  end='')

            if train_iter > 0 and train_iter % args.iter_val == 0:
                # Periodic validation + checkpointing.
                save_model(model, model_file + '_latest')
                val_metrics = validate(args, model, loaders['valid'])

                _save_ckp = ''
                if val_metrics[best_key] > best_metrics:
                    best_metrics = val_metrics[best_key]
                    save_model(model, model_file)
                    _save_ckp = '*'
                print(' {:.4f} | {:.4f} | {:.4f} | {:.2f} |  {:4s} |'.format(
                    val_metrics['loss'], val_metrics['dice'], best_metrics,
                    (time.time() - bg) / 60, _save_ckp))

                # validate() leaves the model in eval mode; restore.
                model.train()

                if args.lrs == 'plateau':
                    lr_scheduler.step(best_metrics)
                else:
                    lr_scheduler.step()
                current_lr = get_lrs(optimizer)

        # Flush gradients from a trailing partial accumulation group so
        # they are applied rather than leaking into the next epoch.
        if pending_grads:
            optimizer.step()
            optimizer.zero_grad()
Пример #28
0
def train_ccblock(model_options):
    """Train the criss-cross attention block and its quantization head.

    Runs 100 epochs over ``VideoDataset``: the CC block is optimized with a
    triplet loss on its (tanh-squashed) embeddings while the quantization
    module is optimized with a soft/hard-distortion objective, each with
    its own RAdam optimizer.  Periodically saves both state dicts and logs
    running loss/distortion to a text file.

    Relies on project classes not visible here: ``VideoDataset``,
    ``RCCAModule``, ``Quantization``, ``RAdam``, ``AllTripletSelector``,
    ``OnlineTripletLoss``.
    """
    # Build the list of training-set shard paths.
    if model_options.trainset_num > 1:
        train_file_paths = [
            model_options.trainset_path.format(i)
            for i in range(1, model_options.trainset_num + 1)
        ]
    else:
        train_file_paths = [model_options.trainset_path]

    # Load datasets.
    print(train_file_paths)
    # NOTE(review): hard-coded machine-specific paths (here and for the
    # init labels below) -- consider moving them into ``model_options``.
    label_paths = "/home/langruimin/BLSTM_pytorch/data/fcv/fcv_train_labels.mat"
    videoset = VideoDataset(train_file_paths, label_paths)
    print(len(videoset))

    # Create models: CC-attention feature extractor + product quantizer.
    model = RCCAModule(1, 1)
    model_quan = Quantization(model_options.subLevel, model_options.subCenters,
                              model_options.dim)

    params_path = os.path.join(model_options.model_save_path,
                               model_options.params_filename)
    params_path_Q = os.path.join(model_options.model_save_path,
                                 model_options.Qparams_filename)
    if model_options.reload_params:
        print('Loading model params...')
        model.load_state_dict(torch.load(params_path))
        print('Done.')

    model = model.cuda()
    model_quan = model_quan.cuda()

    # One optimizer per sub-network so each loss only steps its own params.
    optimizer = RAdam(model.parameters(),
                      lr=1e-3,
                      betas=(0.9, 0.999),
                      weight_decay=1e-4)
    optimizer2 = RAdam(
        model_quan.parameters(),
        lr=1e-3,  # 7e-6
        betas=(0.9, 0.999),
        weight_decay=1e-4)

    # Last-seen learning rates, recorded alongside checkpoints.
    lr_C = ""
    lr_Q = ""

    selector = AllTripletSelector()
    triplet_loss = OnlineTripletLoss(margin=512, triplet_selector=selector)

    batch_idx = 1
    train_loss_rec = open(
        os.path.join(model_options.records_save_path,
                     model_options.train_loss_filename), 'w')
    # try/finally guarantees the log file is closed (and its buffered
    # records flushed) even if training raises -- the original leaked it.
    try:
        error_ = 0.
        loss_ = 0.
        num = 0
        print("##########start train############")
        trainloader = torch.utils.data.DataLoader(videoset,
                                                  batch_size=9,
                                                  shuffle=True,
                                                  num_workers=4,
                                                  pin_memory=True)
        model.train()
        model_quan.train()

        init_train_label = np.load(
            "/home/langruimin/BLSTM_pytorch/data/fcv/init_train_labels.npy")

        for l in range(100):
            for i, (data, index, _, _) in enumerate(trainloader):
                data = data.to(model_options.default_dtype)
                data = data.unsqueeze(1)
                data = data.cuda()
                # CC block: squash embeddings to (-1, 1).
                output_ccblock_mean = torch.tanh(model(data))

                # Quantization block.
                Qhard, Qsoft, SoftDistortion, HardDistortion, JointCenter, error, _ = model_quan(
                    output_ccblock_mean)
                Q_loss = 0.1 * SoftDistortion + HardDistortion + 0.1 * JointCenter

                tri_loss, tri_num = triplet_loss(output_ccblock_mean,
                                                 init_train_label[index])

                # retain_graph: the same forward graph is reused by the
                # triplet backward just below.
                optimizer2.zero_grad()
                Q_loss.backward(retain_graph=True)
                optimizer2.step()

                optimizer.zero_grad()
                tri_loss.backward()
                optimizer.step()

                error_ += error.item()
                loss_ += tri_loss.item()
                num += 1
                if batch_idx % model_options.disp_freq == 0:
                    info = "epoch{0} Batch {1} loss:{2:.3f}  distortion:{3:.3f} " \
                        .format(l, batch_idx, loss_/ num, error_ / num)
                    print(info)
                    train_loss_rec.write(info + '\n')

                batch_idx += 1
            # Reset the running statistics at each epoch boundary.
            batch_idx = 0
            error_ = 0.
            loss_ = 0.
            num = 0

            if (l + 1) % model_options.save_freq == 0:
                print('epoch: ', l, 'New best model. Saving model ...')
                torch.save(model.state_dict(), params_path)
                torch.save(model_quan.state_dict(), params_path_Q)

                for param_group in optimizer.param_groups:
                    lr_C = param_group['lr']
                for param_group in optimizer2.param_groups:
                    lr_Q = param_group['lr']
                record_inf = "saved model at epoch {0} lr_C:{1} lr_Q:{2}".format(
                    l, lr_C, lr_Q)
                train_loss_rec.write(record_inf + '\n')
            print("##########epoch done##########")

        print('train done. Saving model ...')
        torch.save(model.state_dict(), params_path)
        torch.save(model_quan.state_dict(), params_path_Q)
        print("##########train done##########")
    finally:
        train_loss_rec.close()
Пример #29
0
class Trainer:
    """Classification training loop with layer-saturation logging.

    Wraps a model plus train/test loaders, selects an optimizer by name,
    and records per-epoch metrics (and delve layer saturations, via
    ``CheckLayerSat``) to a CSV under ``logs_dir``.  If a CSV for the
    identical configuration already contains >= ``epochs`` rows, training
    is skipped entirely (``experiment_done``).
    """
    def __init__(self,
                 model,
                 train_loader,
                 test_loader,
                 epochs=200,
                 batch_size=60,
                 run_id=0,
                 logs_dir='logs',
                 device='cpu',
                 saturation_device=None,
                 optimizer='None',
                 plot=True,
                 compute_top_k=False,
                 data_prallel=False,
                 conv_method='channelwise',
                 thresh=.99):
        # Saturation can be recorded on a different device than training.
        self.saturation_device = device if saturation_device is None else saturation_device
        self.device = device
        self.model = model
        self.epochs = epochs
        self.plot = plot
        self.compute_top_k = compute_top_k

        if 'cuda' in device:
            cudnn.benchmark = True

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.criterion = nn.CrossEntropyLoss()
        print('Checking for optimizer for {}'.format(optimizer))
        if optimizer == "adam":
            print('Using adam')
            self.optimizer = optim.Adam(model.parameters())
        elif optimizer == 'bad_lr_adam':
            # Deliberately mis-configured optimizer for experiments.
            print('Using adam with to large learning rate')
            self.optimizer = optim.Adam(model.parameters(), lr=0.01)
        elif optimizer == "SGD":
            print('Using SGD')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.5,
                                       momentum=0.9)
        elif optimizer == "LRS":
            print('Using LRS')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.01,
                                       momentum=0.9)
            self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, 5)
        elif optimizer == "radam":
            print('Using radam')
            self.optimizer = RAdam(model.parameters())
        else:
            raise ValueError('Unknown optimizer {}'.format(optimizer))
        self.opt_name = optimizer
        save_dir = os.path.join(logs_dir, model.name, train_loader.name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # Encode the full configuration in the CSV name so that re-runs
        # with the same settings are detected below.
        self.savepath = os.path.join(
            save_dir,
            f'{model.name}_bs{batch_size}_e{epochs}_t{int(thresh*1000)}_id{run_id}.csv'
        )
        self.experiment_done = False
        if os.path.exists(self.savepath):
            trained_epochs = len(pd.read_csv(self.savepath, sep=';'))

            if trained_epochs >= epochs:
                self.experiment_done = True
                print(
                    'Experiment Logs for the exact same experiment with identical run_id was detected, training will be skipped, consider using another run_id'
                )
        self.parallel = data_prallel
        if data_prallel:
            self.model = nn.DataParallel(self.model, ['cuda:0', 'cuda:1'])
        writer = CSVandPlottingWriter(self.savepath.replace('.csv', ''),
                                      fontsize=16,
                                      primary_metric='test_accuracy')
        self.pooling_strat = conv_method
        print('Setting saturation recording threshold to', thresh)
        # BUG FIX: sat_threshold was hard-coded to .99, silently ignoring
        # the ``thresh`` constructor argument it is advertised to use.
        self.stats = CheckLayerSat(self.savepath.replace('.csv', ''),
                                   writer,
                                   model,
                                   ignore_layer_names='convolution',
                                   stats=['lsat'],
                                   sat_threshold=thresh,
                                   verbose=False,
                                   conv_method=conv_method,
                                   log_interval=1,
                                   device=self.saturation_device,
                                   reset_covariance=True,
                                   max_samples=None)

    def train(self):
        """Run the full training loop; return the metrics-CSV path.

        Returns None immediately when the experiment was already completed
        in a previous run (``experiment_done``).
        """
        if self.experiment_done:
            return
        self.model.to(self.device)
        for epoch in range(self.epochs):
            print(
                "{} Epoch {}, training loss: {}, training accuracy: {}".format(
                    now(), epoch, *self.train_epoch()))
            self.test()
            if self.opt_name == "LRS":
                print('LRS step')
                self.lr_scheduler.step()
            self.stats.add_saturations()
        self.stats.close()
        # BUG FIX: savepath already ends in '.csv'; the original appended
        # a second '.csv', returning a path that does not exist.
        return self.savepath

    def train_epoch(self):
        """Train for one epoch; return (avg loss per sample, accuracy).

        Accuracy is top-5 (via ``accuracy``) when ``compute_top_k`` is set,
        else plain top-1.  Progress is printed every 10 batches.
        """
        self.model.train()
        correct = 0
        total = 0
        running_loss = 0
        old_time = time()
        top5_accumulator = 0
        for batch, data in enumerate(self.train_loader):
            if batch % 10 == 0 and batch != 0:
                print(
                    batch, 'of', len(self.train_loader), 'processing time',
                    time() - old_time,
                    "top5_acc:" if self.compute_top_k else 'acc:',
                    round(top5_accumulator /
                          (batch), 3) if self.compute_top_k else correct /
                    total)
                old_time = time()
            inputs, labels = data
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            if self.compute_top_k:
                top5_accumulator += accuracy(outputs, labels, (5, ))[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)

            correct += (predicted == labels).sum().item()
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
        self.stats.add_scalar('training_loss', running_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('training_accuracy',
                                  (top5_accumulator / (batch + 1)))
        else:
            self.stats.add_scalar('training_accuracy', correct / total)
        return running_loss / total, correct / total

    def test(self, save=True):
        """Evaluate on the test loader; return (accuracy, avg loss).

        Logs loss/accuracy scalars to ``self.stats`` and, when ``save`` is
        True, checkpoints the model next to the metrics CSV.
        """
        self.model.eval()
        correct = 0
        total = 0
        test_loss = 0
        top5_accumulator = 0
        with torch.no_grad():
            for batch, data in enumerate(self.test_loader):
                inputs, labels = data
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                if self.compute_top_k:
                    top5_accumulator += accuracy(outputs, labels, (5, ))[0]
                test_loss += loss.item()

        self.stats.add_scalar('test_loss', test_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('test_accuracy',
                                  top5_accumulator / (batch + 1))
            print('{} Test Top5-Accuracy on {} images: {:.4f}'.format(
                now(), total, top5_accumulator / (batch + 1)))

        else:
            self.stats.add_scalar('test_accuracy', correct / total)
            print('{} Test Accuracy on {} images: {:.4f}'.format(
                now(), total, correct / total))
        if save:
            torch.save({'model_state_dict': self.model.state_dict()},
                       self.savepath.replace('.csv', '.pt'))
        return correct / total, test_loss / total
class Trainer(object):
    '''Training and validation driver for a segmentation model.

    Configuration comes almost entirely from module-level globals that are
    not visible in this block (``args``, ``data_folder``, ``train_rle_path``,
    ``log``, ``save_model_name``, ``basic_name``, ``model``), so their exact
    semantics are assumed from usage here.
    '''
    def __init__(self, model):
        # Fold / loader configuration, all sourced from the global ``args``.
        self.fold = args.fold
        self.total_folds = 5
        self.num_workers = 6
        self.batch_size = {
            "train": args.batch_size,
            "val": args.batch_size
        }  # 4
        # Accumulate gradients so the effective batch size is ~32
        # regardless of the per-step batch size.
        self.accumulation_steps = 32 // self.batch_size['train']
        self.lr = args.learning_rate
        self.num_epochs = args.epochs
        self.best_loss = float("inf")
        self.best_dice = 0
        self.phases = ["train", "val"]
        self.device = torch.device("cuda:0")
        # NOTE(review): makes every newly created tensor default to CUDA --
        # a process-wide side effect, not scoped to this trainer.
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
        self.net = model
        self.criterion = MixedLoss(10.0, 2.0)

        if args.swa is True:
            # Stochastic Weight Averaging wrapped around an RAdam base.
            # base_opt = torch.optim.SGD(self.net.parameters(), lr=args.max_lr, momentum=args.momentum, weight_decay=args.weight_decay)
            base_opt = RAdam(self.net.parameters(), lr=self.lr)
            self.optimizer = SWA(base_opt,
                                 swa_start=38,
                                 swa_freq=1,
                                 swa_lr=args.min_lr)
            # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, scheduler_step, args.min_lr)
        else:
            if args.optimizer.lower() == 'adam':
                self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
            elif args.optimizer.lower() == 'radam':
                self.optimizer = RAdam(
                    self.net.parameters(), lr=self.lr
                )  # betas=(args.beta1, args.beta2),weight_decay=args.weight_decay
            elif args.optimizer.lower() == 'sgd':
                self.optimizer = torch.optim.SGD(
                    self.net.parameters(),
                    lr=args.max_lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

        # NOTE(review): if args.scheduler is neither option, self.scheduler
        # is never set and iterate()'s caller (start) will raise
        # AttributeError at self.scheduler.step().
        if args.scheduler.lower() == 'reducelronplateau':
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               mode="min",
                                               patience=args.patience,
                                               verbose=True)
        elif args.scheduler.lower() == 'clr':
            # NOTE(review): CyclicLR.step() does not take a metric; start()
            # calls self.scheduler.step(val_loss) -- verify compatibility.
            self.scheduler = CyclicLR(self.optimizer,
                                      base_lr=self.lr,
                                      max_lr=args.max_lr)
        self.net = self.net.to(self.device)
        cudnn.benchmark = True
        # One provider-backed dataloader per phase ('train' / 'val');
        # mean/std are the standard ImageNet normalization constants.
        self.dataloaders = {
            phase: provider(
                fold=args.fold,
                total_folds=5,
                data_folder=data_folder,
                df_path=train_rle_path,
                phase=phase,
                size=args.img_size_target,
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                batch_size=self.batch_size[phase],
                num_workers=self.num_workers,
            )
            for phase in self.phases
        }
        # Per-phase metric histories, appended to once per epoch.
        self.losses = {phase: [] for phase in self.phases}
        self.iou_scores = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.kaggle_metric = {phase: [] for phase in self.phases}

    def forward(self, images, targets):
        """Move a batch to the device, run the net, return (loss, outputs)."""
        images = images.to(self.device)
        masks = targets.to(self.device)
        outputs = self.net(images)
        loss = self.criterion(
            outputs, masks
        )  # weighted_lovasz  # lovasz_hinge(outputs, masks) # self.criterion(outputs, masks)
        return loss, outputs

    def iterate(self, epoch, phase):
        """Run one epoch over ``phase`` ('train' or 'val').

        Uses gradient accumulation in the train phase (stepping every
        ``accumulation_steps`` batches) and feeds every batch into a
        ``Meter`` for metric aggregation.  Returns
        (epoch_loss, dice, iou, scores, kaggle_metric).
        """
        meter = Meter(phase, epoch)
        start = time.strftime("%H:%M:%S")
        print(f"Starting epoch: {epoch} | phase: {phase} | ⏰: {start}")
        batch_size = self.batch_size[phase]
        start = time.time()
        # train(True) in the train phase, train(False) i.e. eval otherwise.
        self.net.train(phase == "train")
        dataloader = self.dataloaders[phase]
        running_loss = 0.0
        total_batches = len(dataloader)
        tk0 = tqdm(dataloader, total=total_batches)
        self.optimizer.zero_grad()
        for itr, batch in enumerate(tk0):
            images, targets = batch
            loss, outputs = self.forward(images, targets)
            # Scale down so accumulated gradients match a full-batch step.
            loss = loss / self.accumulation_steps
            if phase == "train":
                loss.backward()
                if (itr + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            running_loss += loss.item()
            outputs = outputs.detach().cpu()
            meter.update(targets, outputs)
            tk0.set_postfix(loss=(running_loss / ((itr + 1))))
        if args.swa is True:
            # Swap in the SWA-averaged weights before metrics/checkpoints.
            self.optimizer.swap_swa_sgd()
        # Undo the accumulation scaling to report the true mean batch loss.
        epoch_loss = (running_loss * self.accumulation_steps
                      ) / total_batches  # running_loss / total_batches
        dice, iou, scores, kaggle_metric = epoch_log(phase, epoch, epoch_loss,
                                                     meter,
                                                     start)  # kaggle_metric
        write_event(log, dice, loss=epoch_loss)
        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(dice)
        self.iou_scores[phase].append(iou)
        self.kaggle_metric[phase].append(kaggle_metric)
        torch.cuda.empty_cache()
        return epoch_loss, dice, iou, scores, kaggle_metric  # kaggle_metric

    def start(self):
        """Run the full train/val loop with TensorBoard logging.

        Resumes from a checkpoint at ``model_path`` if one exists, saves a
        new checkpoint whenever val loss or val dice improves, and returns
        the ``scores`` of the last loss-improving epoch.
        """
        # Recreate the TensorBoard log directory for this version/fold.
        if os.path.exists(args.log_path + 'v' + str(args.version) + '/' +
                          str(args.fold)):
            shutil.rmtree(args.log_path + 'v' + str(args.version) + '/' +
                          str(args.fold))
        else:
            os.makedirs(args.log_path + 'v' + str(args.version) + '/' +
                        str(args.fold))
        writer = SummaryWriter(args.log_path + 'v' + str(args.version) + '/' +
                               str(args.fold))

        # NOTE(review): num_snapshot and best_acc are never used (only
        # referenced by commented-out snapshot code below).
        num_snapshot = 0
        best_acc = 0
        model_path = args.weights_path + 'v' + str(
            args.version) + '/' + save_model_name
        if os.path.exists(model_path):
            # Resume: restore weights and best-so-far metrics.
            state = torch.load(model_path,
                               map_location=lambda storage, loc: storage)
            # NOTE(review): this restores into the *global* ``model``, not
            # self.net -- looks like a bug unless they alias the same object.
            model.load_state_dict(state["state_dict"])  # ["state_dict"]
            epoch = state['epoch']
            self.best_loss = state['best_loss']
            self.best_dice = state['best_dice']
            # NOTE(review): the next two assignments are no-ops.
            state['state_dict'] = state['state_dict']
            state['optimizer'] = state['optimizer']
        else:
            epoch = 1
            self.best_loss = float('inf')
            self.best_dice = 0

        for epoch in range(epoch, self.num_epochs + 1):
            print('-' * 30, 'Epoch:', epoch, '-' * 30)
            train_loss, train_dice, train_iou, train_scores, train_kaggle_metric = self.iterate(
                epoch, "train")  # train_kaggle_metric
            # Snapshot of everything needed to resume from this epoch.
            state = {
                "epoch": epoch,
                "best_loss": self.best_loss,
                "best_dice": self.best_dice,
                "state_dict": self.net.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            # NOTE(review): KeyboardInterrupt is caught *inside* the loop,
            # so Ctrl+C saves a snapshot but training then continues with
            # the next epoch rather than stopping.
            try:
                val_loss, val_dice, val_iou, val_scores, val_kaggle_metric = self.iterate(
                    epoch, "val")  # val_kaggle_metric
                self.scheduler.step(val_loss)
                if val_loss < self.best_loss:
                    print("******** New optimal found, saving state ********")
                    state["best_loss"] = self.best_loss = val_loss
                    torch.save(state, model_path)
                    try:
                        scores = val_scores
                    except:
                        scores = 'None'
                if val_dice > self.best_dice:
                    print("******** Best Dice Score, saving state ********")
                    state["best_dice"] = self.best_dice = val_dice
                    best_dice__path = args.weights_path + 'v' + str(
                        args.version
                    ) + '/' + 'best_dice_' + basic_name + '.pth'
                    torch.save(state, best_dice__path)
                # if val_dice > best_acc:
                #     print("******** New optimal found, saving state ********")
                #     # state["best_acc"] = self.best_acc = val_dice
                #     best_acc = val_dice
                #     best_param = self.net.state_dict()

                # if (epoch + 1) % scheduler_step == 0:
                #     # torch.save(best_param, args.save_weight + args.weight_name + str(idx) + str(num_snapshot) + '.pth')
                #     save_model_name = basic_name + '.pth' # '_' +str(num_snapshot)
                #     torch.save(best_param, args.weights_path + 'v' + str(args.version) + '/' + save_model_name)
                #     # state
                #     try:
                #         scores = val_scores
                #     except:
                #         scores = 'None'
                #     optimizer = torch.optim.SGD(self.net.parameters(), lr=args.max_lr, momentum=args.momentum,
                #                                 weight_decay=args.weight_decay)
                #     self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduler_step, args.min_lr)
                #     num_snapshot += 1
                #     best_acc = 0
                writer.add_scalars('loss', {
                    'train': train_loss,
                    'val': val_loss
                }, epoch)
                writer.add_scalars('dice_score', {
                    'train': train_dice,
                    'val': val_dice
                }, epoch)
                writer.add_scalars('IoU', {
                    'train': train_iou,
                    'val': val_iou
                }, epoch)
                writer.add_scalars('New_Dice', {
                    'train': train_kaggle_metric,
                    'val': val_kaggle_metric
                }, epoch)
            except KeyboardInterrupt:
                print('Ctrl+C, saving snapshot')
                torch.save(
                    state, args.weights_path + 'v' + str(args.version) + '/' +
                    save_model_name)
                print('done.')
            # writer.add_scalars('Accuracy', {'train': train_kaggle_metric, 'val': val_kaggle_metric}, epoch)

        # writer.export_scalars_to_json(args.log_path + 'v' + str(args.version) + '/' + basic_name + '.json')
        writer.close()
        # NOTE(review): ``scores`` is only bound when val loss improved at
        # least once; otherwise this raises NameError/UnboundLocalError.
        return scores